Compare commits

...

159 Commits

Author SHA1 Message Date
Timothy 736756b257 chore: fix test 2026-03-20 21:22:29 -07:00
Timothy 90efe7009d chore: lint 2026-03-20 21:13:22 -07:00
Timothy 4adb369bde chore: lint 2026-03-20 21:12:03 -07:00
Timothy d4a30eb2f3 feat: image model fallback 2026-03-20 20:18:07 -07:00
Timothy 94bb4a2984 Merge branch 'main' into feat/image-capabilities 2026-03-20 18:42:55 -07:00
Timothy 648bad26ed feat: user input image content 2026-03-20 18:40:28 -07:00
RichardTang-Aden f0c7470f3d Merge pull request #6663 from sundaram2021/fix/missing-minimax-option-on-windows
fix: minimax option in powershell quickstart
2026-03-20 17:00:11 -07:00
RichardTang-Aden fe533b72a6 Merge pull request #6648 from levxn/main
Antigravity subscription support as an LLM provider
2026-03-20 16:52:38 -07:00
Richard Tang e581767cab chore: ruff lint 2026-03-20 16:50:50 -07:00
Richard Tang 0663ee5950 feat: validate the existing credentials before auth 2026-03-20 16:45:56 -07:00
Richard Tang 4b97baa34b feat: native google oauth for antigravity support 2026-03-20 16:40:15 -07:00
levxn a89296d397 lint fix 2026-03-21 02:35:09 +05:30
Levin d568912ba2 Merge branch 'aden-hive:main' into main 2026-03-21 01:32:13 +05:30
Levin c4d7980058 Merge pull request #1 from levxn/subscription/antigravity
Subscription/antigravity
2026-03-21 01:30:27 +05:30
Timothy @aden 8549fe8238 Merge pull request #6635 from vakrahul/fix/skill-structured-errors-6366
feat: structured skill error codes and diagnostics (closes #6366)
2026-03-20 12:45:35 -07:00
levxn 2b8d85bb95 fix tool calling: Antigravity models expect a thought_signature in functionCall parts, otherwise requests fail with a 400 invalid-arguments error 2026-03-20 23:26:50 +05:30
levxn 07f7801166 test v1 2026-03-20 22:32:30 +05:30
Levin 1f12a45151 Merge branch 'aden-hive:main' into main 2026-03-20 22:01:22 +05:30
Arshad Uzzama Shaik 936e02e8e6 fix(security): prevent symlink-based sandbox escape in get_secure_path (closes #1167) (#5635)
* fix(security): prevent symlink-based sandbox escape in get_secure_path (closes #1167)

* style: apply ruff formatting to tools to satisfy CI

---------

Co-authored-by: Arshad Shaik <arshad.shaik@violetis.ai>
2026-03-20 19:16:47 +08:00
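The usual shape of such a fix is to resolve symlinks before the containment check. A minimal sketch, assuming a sandbox rooted at `base_dir`; the PR's actual `get_secure_path` signature isn't shown here:

```python
from pathlib import Path


def get_secure_path(base_dir: Path, user_path: str) -> Path:
    """Resolve user_path inside base_dir, rejecting symlink escapes."""
    base = base_dir.resolve()
    # resolve() follows symlinks, so a link pointing outside the sandbox
    # collapses to its real target before the containment check below.
    candidate = (base / user_path).resolve()
    if candidate != base and base not in candidate.parents:
        raise PermissionError(f"path escapes sandbox: {user_path}")
    return candidate
```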
Hundao d59fe1e109 fix(graph): remove dead check_constraint placeholder (#6660)
Never called anywhere in the codebase. Constraints are enforced
via prompt context, not runtime validation.
2026-03-20 18:44:18 +08:00
Sundaram Kumar Jha 274318d3e5 fix: minimax option in powershell quickstart 2026-03-20 15:33:26 +05:30
Anurag Kumar 0f0884c2e0 fix(tools): handle non-HTML content and add PDF URL support (#438)
* feat(tools): add URL support to pdf_read tool

Enable pdf_read to accept both local file paths and HTTP/HTTPS URLs.
Downloads PDF content to temporary file when URL is provided, validates
content-type, and cleans up automatically after extraction.

- Detect URL inputs (http:// or https://)
- Download PDF with httpx (60s timeout)
- Validate Content-Type is application/pdf
- Use temporary file for URL-based PDFs
- Automatic cleanup in finally block
- Maintains backward compatibility with local paths

Completes the workflow: web_scrape error on PDF → pdf_read from URL

* test(tools): Add test coverage for new features in web_scrape and pdf_read tools

* style: fix lint issues in pdf_read URL support

---------

Co-authored-by: Anurag <anuragkr-codes@users.noreply.github.com>
Co-authored-by: hundao <alchemy_wimp@hotmail.com>
2026-03-20 16:36:25 +08:00
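A sketch of the URL branch this commit describes; beyond `httpx`, the 60s timeout, and the content-type check named above, the function and helper names are illustrative:

```python
import tempfile
from pathlib import Path

import httpx


def extract_text(path: Path) -> str:
    """Stand-in for the tool's actual PDF text extraction."""
    raise NotImplementedError


def read_pdf(source: str) -> str:
    if source.startswith(("http://", "https://")):
        tmp_path: Path | None = None
        try:
            resp = httpx.get(source, timeout=60.0, follow_redirects=True)
            resp.raise_for_status()
            if "application/pdf" not in resp.headers.get("content-type", ""):
                raise ValueError("URL did not return a PDF")
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
                tmp.write(resp.content)
                tmp_path = Path(tmp.name)
            return extract_text(tmp_path)
        finally:
            if tmp_path is not None:
                tmp_path.unlink(missing_ok=True)  # automatic cleanup
    return extract_text(Path(source))  # local paths keep working
```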
Timothy @aden 764012c598 Merge pull request #6652 from aden-hive/feature/absolutely-parallel
fix: parallel subagent execution display, session resume bugs, and GCU termination
2026-03-19 20:21:47 -07:00
Timothy fd4dc1a69a fix: google_sheets JSON parse error before credentials check
Move _get_client() before JSON deserialization so missing-credentials
errors aren't masked by input validation. Wrap json.loads in try/except
for non-JSON string inputs.
2026-03-19 20:13:18 -07:00
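The reordering the message describes, as a hedged method-body sketch; `_get_client` and `append_rows` are assumed names:

```python
import json


def run(self, values):
    # Credentials first: a missing-credentials error must not be masked
    # by input validation.
    client = self._get_client()
    # Then parse, wrapping json.loads so a non-JSON string yields a clear
    # error instead of a raw JSONDecodeError.
    try:
        rows = json.loads(values) if isinstance(values, str) else values
    except json.JSONDecodeError as e:
        raise ValueError(f"values must be valid JSON: {e}") from e
    return client.append_rows(rows)
```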
Timothy 377cd39c2a chore: lint 2026-03-19 20:07:42 -07:00
Timothy e92caeef24 fix: line too long in google_sheets_tool
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 20:06:31 -07:00
Timothy @aden b7e6226478 Update asset link in README.md 2026-03-19 19:41:19 -07:00
Timothy a995818db2 fix: subagent bubble boundary 2026-03-19 17:57:33 -07:00
Timothy 0772b4d300 feat: better subagent interleave logic 2026-03-19 16:58:34 -07:00
Timothy 684e0d8dc6 fix: no memory consolidation for worker 2026-03-19 16:58:00 -07:00
Timothy d284c5d790 feat: parallel execution display 2026-03-19 15:25:21 -07:00
Timothy 7a9b9666c4 fix: refresh system prompt with preamble 2026-03-19 15:25:04 -07:00
Timothy a852cb91bf fix: non-blocking memory consolidation 2026-03-19 15:24:30 -07:00
Timothy 2f21e9eb4b fix: session reload preamble 2026-03-19 15:24:12 -07:00
Timothy 8390ef8731 fix: google sheet tool support json string input 2026-03-19 15:23:31 -07:00
levxn 8d21479c24 fixing lint errors 2026-03-20 02:34:58 +05:30
levxn 965dec3ba1 fix errors; finalise credential fetch (client ID and secret) in fallback paths 2026-03-20 02:32:42 +05:30
Timothy d4b54446be Merge branch 'main' into feat/image-capabilities 2026-03-19 11:12:33 -07:00
Levin 7992b862c2 Merge branch 'aden-hive:main' into main 2026-03-19 22:16:10 +05:30
Ananya Verma 44b3e0eaa2 Configure pytest to ignore DeprecationWarning (#1727)
Add pytest configuration to ignore specific warnings.
2026-03-19 23:17:50 +08:00
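A plausible `pytest.ini` fragment for this kind of change; the PR's exact filter list may differ:

```ini
[pytest]
filterwarnings =
    ignore::DeprecationWarning
```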
levxn f480fc2b94 pick up Antigravity OAuth creds properly 2026-03-19 20:26:42 +05:30
vakrahul 2844dbf19f feat: structured skill error codes and diagnostics (closes #6366) 2026-03-19 13:18:18 +05:30
Timothy @aden 22b7e4b0c3 Merge pull request #6624 from aden-hive/feature/agent-skills
feat: agent skills system and observability improvements
2026-03-18 20:28:34 -07:00
Timothy 5413833a69 fix: tool test 2026-03-18 20:20:32 -07:00
bryan 02e1a4584a fix: autolaunch gui (windows) 2026-03-18 20:15:25 -07:00
Timothy 520840b1dd fix: no immediate run digest 2026-03-18 20:14:20 -07:00
bryan ee96147336 feat: autolaunch gui (mac) 2026-03-18 20:11:03 -07:00
Timothy 705cef4dc1 fix: context window display 2026-03-18 20:05:48 -07:00
Timothy ab26e64122 Merge remote-tracking branch 'origin/main' into feature/agent-skills 2026-03-18 19:41:39 -07:00
Timothy @aden f365e219cb Merge pull request #6615 from aden-hive/feat/worker-llm
feat: support separate LLM model for worker agents
2026-03-18 19:41:06 -07:00
Timothy 01621881c2 chore: lint 2026-03-18 19:40:41 -07:00
Timothy f7639f8572 fix: realtime context display 2026-03-18 19:29:31 -07:00
Timothy fc643060ce fix: better message bubble handling 2026-03-18 17:49:55 -07:00
Timothy 9aebeb181e feat: compaction debugger 2026-03-18 17:42:10 -07:00
Timothy acbbfaaa79 feat: compaction debug 2026-03-18 17:41:22 -07:00
Timothy bf170bce10 feat: enable mcp server reuse by default 2026-03-18 17:30:31 -07:00
Timothy 0a090d058b Merge remote-tracking branch 'origin/main' into feature/agent-skills 2026-03-18 17:11:12 -07:00
Timothy @aden 47bfadaad9 Merge pull request #6622 from aden-hive/fix/resume-empty-message
Fix empty queen message bubbles on session resume
2026-03-18 16:55:50 -07:00
Timothy d968dcd44c Merge branch 'main' into feature/agent-skills 2026-03-18 16:53:42 -07:00
Timothy @aden 6fdaa9ea50 Merge pull request #6534 from VasuBansal7576/codex/mcp-connection-manager-6348-draft
feat: add shared MCP connection manager
2026-03-18 16:52:44 -07:00
Timothy @aden 4d251fbdc2 Merge pull request #6531 from VasuBansal7576/codex/mcp-transports-6347-single
feat: add unix and sse MCP transports
2026-03-18 16:38:17 -07:00
Timothy 6acceed288 feat: hive debugger 2026-03-18 16:26:55 -07:00
Richard Tang 8dd1d6e3aa chore: lint 2026-03-18 16:01:32 -07:00
Timothy 1da28644a6 Merge branch 'main' into feature/agent-skills 2026-03-18 15:38:49 -07:00
Timothy 6452fe7fef fix: discord bot 2026-03-18 15:34:08 -07:00
Richard Tang acff008bd2 fix: empty message render 2026-03-18 15:26:56 -07:00
Timothy 651d6850a1 fix: bounty tracker change 2026-03-18 14:49:21 -07:00
Timothy c7fdc92594 fix: bounty script 2026-03-18 14:27:24 -07:00
Richard Tang 43602a8801 fix: trim to remove empty message 2026-03-18 13:55:57 -07:00
Timothy @aden 3da04265a6 Merge pull request #6566 from levxn/skills/context-protection
feat(skills): AS-9 and AS-10 — skill directory allowlisting and context protection for activated skills
2026-03-18 13:51:25 -07:00
Timothy @aden 4c98f0d2d0 Merge pull request #6564 from levxn/skills/resource-loading
feat(skills): AS-6 tier 3 resource loading — base_dir in catalog XML and skill dirs wired through execution stack
2026-03-18 13:50:54 -07:00
bryan d84c3364d0 chore: update to pass make test 2026-03-18 13:20:56 -07:00
Timothy @aden ae921f6cee Merge pull request #6619 from aden-hive/fix/claude-code-subscription-support
fix(llm): restore Claude Code subscription OAuth support
2026-03-18 13:08:27 -07:00
Timothy 6b506a1c08 chore: lint 2026-03-18 13:05:00 -07:00
Timothy 0c9f4fa97e fix(llm): restore Claude Code subscription (OAuth) support after Anthropic API change
Anthropic tightened OAuth validation on 2026-03-17, requiring a
specific User-Agent header and a billing integrity system block for
subscription-authenticated requests. Without these, all OAuth calls
return HTTP 400 with a generic "Error" message.

Changes:
- Add billing integrity system block (SHA-256 hash derived from first
  user message content) prepended to system messages on OAuth requests
- Set User-Agent to claude-code/<version> for OAuth sessions
- Fix OAuth header patch to detect tokens in x-api-key (not just
  Authorization) and add required beta/browser-access headers
- Set litellm.drop_params=True to prevent unsupported params like
  stream_options from leaking to Anthropic (causes 400)
- Skip stream_options entirely for Anthropic models
- Honour LITELLM_LOG env var for debug logging instead of hardcoding
  LiteLLM logger to WARNING
2026-03-18 13:02:24 -07:00
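A hedged sketch of the client-side knobs this message lists. Only `litellm.drop_params` is named in the commit itself; the billing-block shape and header values are inferred from the commit text, not from documented Anthropic behavior:

```python
import hashlib

import litellm

# Named in the commit: drop unsupported params (e.g. stream_options)
# instead of sending them to Anthropic, where they cause HTTP 400.
litellm.drop_params = True


def billing_integrity_block(first_user_message: str) -> dict:
    """Assumed shape: a system block derived from a SHA-256 of the first user message."""
    digest = hashlib.sha256(first_user_message.encode("utf-8")).hexdigest()
    return {"type": "text", "text": digest}


def oauth_headers(token: str, claude_code_version: str) -> dict[str, str]:
    """Assumed header set for subscription (OAuth) sessions."""
    return {
        "Authorization": f"Bearer {token}",  # token detected even if it arrived in x-api-key
        "User-Agent": f"claude-code/{claude_code_version}",
    }
```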
Richard Tang 95e30bc607 chore: remove old queen history endpoint 2026-03-18 12:43:30 -07:00
bryan 0f1f0090b0 chore: linter update 2026-03-18 12:41:01 -07:00
bryan c0da3bec02 feat: strip image content for non-vision models 2026-03-18 12:40:30 -07:00
bryan 9dadb5264d feat: add screenshot image passthrough to LLM 2026-03-18 12:40:18 -07:00
bryan e39e6a75cc feat: add ref system for aria snapshots 2026-03-18 12:36:51 -07:00
Richard Tang 23c66d1059 feat: worker model loading 2026-03-18 12:14:02 -07:00
Richard Tang b9d529d94e feat: support separate worker llm setup 2026-03-18 11:19:44 -07:00
Bryan @ Aden 1c9b09fb78 Merge pull request #6602 from sundaram2021/cleanup/remove-commit-message-txt
micro-fix: remove unnecessary commit message file
2026-03-18 17:40:50 +00:00
Timothy @aden 9fb14f23d2 Merge pull request #6526 from sundaram2021/feature/openrouter-api-key-support
feat: openrouter api key support
2026-03-18 10:15:40 -07:00
Sundaram Kumar Jha 4795dc4f68 chore: clean useless commit message file 2026-03-18 16:45:10 +05:30
Sundaram Kumar Jha acf0f804c5 style(llm): apply ruff formatting 2026-03-18 10:54:06 +05:30
Sundaram Kumar Jha 4e2951854b fix(openrouter): harden quickstart setup and model validation 2026-03-18 10:39:58 +05:30
Sundaram Kumar Jha 80dfb429d7 refactor(review): remove out-of-scope PR changes 2026-03-18 10:39:48 +05:30
Timothy @aden 9c0ba77e22 Replace demo image with GitHub asset link
Updated README to include new asset link and removed demo image.
2026-03-17 20:59:14 -07:00
Timothy @aden 46b4651073 Merge pull request #6589 from aden-hive/fix/data-disclosure-gaps
Fix data disclosure gaps, add worker run digests, clean up deprecated tools
2026-03-17 20:46:12 -07:00
Timothy 86dd5246c6 Merge remote-tracking branch 'origin/fix/resume-with-scheduler' into fix/data-disclosure-gaps 2026-03-17 20:44:28 -07:00
Timothy a1227c88ee Merge remote-tracking branch 'origin/fix/resume-with-scheduler' into fix/data-disclosure-gaps 2026-03-17 20:42:25 -07:00
Timothy 535d7ab568 fix: worker digest sub event 2026-03-17 20:41:56 -07:00
Richard Tang af10494b31 chore: ruff lint 2026-03-17 20:41:08 -07:00
Richard Tang 39c1042827 fix: fall back to queen-only session when worker load fails on cold restore 2026-03-17 20:38:41 -07:00
Richard Tang 16e7dc11f4 fix: don't overwrite meta in queen creation 2026-03-17 20:27:39 -07:00
Richard Tang 7a27babefd feat: track and resume the session by phase 2026-03-17 20:22:54 -07:00
Timothy d53ae9d51d fix: deprecated tests 2026-03-17 20:20:21 -07:00
Timothy 910cf7727d Merge remote-tracking branch 'origin/fix/resume-with-scheduler' into fix/data-disclosure-gaps 2026-03-17 20:14:25 -07:00
Timothy 1698605f15 chore: lint 2026-03-17 19:59:23 -07:00
Timothy eda124a123 chore: lint 2026-03-17 19:58:08 -07:00
Timothy 15e9ce8d2f Merge remote-tracking branch 'origin/feature/session-digest' into fix/data-disclosure-gaps 2026-03-17 19:45:07 -07:00
Timothy c01dd603d7 fix: digest invocation 2026-03-17 19:44:22 -07:00
Timothy 9d5157d69f feat: queen subscribe to worker digest 2026-03-17 19:23:43 -07:00
Timothy d78795bdf5 Merge remote-tracking branch 'origin/feature/session-digest' into fix/data-disclosure-gaps 2026-03-17 19:15:22 -07:00
Timothy ff2b7f473e fix: subagent execution 2026-03-17 19:15:07 -07:00
Timothy 73c9a91811 feat: add worker memory consolidation hooks 2026-03-17 19:14:07 -07:00
Timothy 27b765d902 Merge branch 'feature/session-digest' into fix/data-disclosure-gaps 2026-03-17 18:32:20 -07:00
Timothy fddba419be fix: minor issues 2026-03-17 18:30:57 -07:00
Timothy f42d6308e8 Merge branch 'main' into fix/data-disclosure-gaps 2026-03-17 17:50:36 -07:00
Timothy c167002754 fix: data disclosure gaps 2026-03-17 17:50:08 -07:00
Timothy @aden ea26ee7d0c Merge pull request #6568 from aden-hive/feature/node-focus-prompt
Inject execution-scope preamble into worker node system prompts
2026-03-17 17:38:49 -07:00
Richard Tang 5280e908b2 feat: change the agent last active time 2026-03-17 17:35:01 -07:00
RichardTang-Aden 1c5dd8c664 Merge pull request #5178 from Schlaflied/feat/sdr-agent-template
feat(templates): add SDR Agent sample template
2026-03-17 16:05:45 -07:00
Richard Tang 3aca153be5 fix: add missing flowchart and terminal nodes 2026-03-17 16:03:29 -07:00
Timothy 65c8e1653c chore: lint 2026-03-17 15:31:36 -07:00
Timothy 58e4fa918c feat: make worker node aware of boundaries 2026-03-17 15:28:41 -07:00
Timothy 3af13d3f90 feat: session digest for run scoped diary 2026-03-17 14:25:32 -07:00
levxn b799789dbe fixing lint 2026-03-18 02:15:58 +05:30
levxn 2cd73dfccc implements AS-9 and AS-10 2026-03-18 02:06:51 +05:30
levxn 57d77d5479 fixing lint 2026-03-18 01:32:24 +05:30
levxn 5814021773 skills trust gate merged properly into resource loading branch 2026-03-18 01:18:20 +05:30
levxn 4f4cc9c8ce halfway done commit 2026-03-18 00:59:35 +05:30
Timothy d9c840eee5 chore: resolve merge conflicts with feature/agent-skills
Integrate SkillsManager refactor from base branch. Trust gating (AS-13)
is now wired into SkillsManager._do_load() instead of inline in runner.py,
with the interactive flag passed through SkillsManagerConfig.
2026-03-17 11:55:11 -07:00
Timothy @aden d2eb86e534 Merge pull request #6540 from sundaram2021/fix/make-windows-compatibility
fix make test compatibility on windows
2026-03-17 11:41:32 -07:00
Timothy 03842353e4 Merge branch 'main' into feature/openrouter-api-key-support 2026-03-17 11:21:53 -07:00
Schlaflied 48747e20af fix: remove personal oauth credential entries from .gitignore
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 13:53:16 -04:00
Schlaflied 58af593af6 revert: remove unrelated changes from previous commit
Restore .claude/settings.json and revert .gitignore change
that were accidentally included in the sdr-agent refactor commit.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 13:52:44 -04:00
Schlaflied 450575a927 refactor(sdr-agent): reuse agent.start() in tui command and fix mock mode
- Replace duplicated setup code in tui command with agent.start(mock_mode=mock)
- Fix mock mode to use MockLLMProvider instead of llm=None
- Add demo_contacts.json sample data for template testing
- Untrack .claude/settings.json and add to .gitignore

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 13:52:10 -04:00
Schlaflied eac2bb19b2 fix(sdr-agent): fix agent runtime lifecycle and mcp config
- Replace self._executor with self._agent_runtime (AgentRuntime | None)
- Import AgentRuntime for proper type annotation
- Add missing await self._agent_runtime.start() in start() — runtime
  was created but never started, causing silent failures at runtime
- Add self._agent_runtime = None reset in stop() for clean restart
- Remove redundant self._graph is None guard in trigger_and_wait()
- Update mcp_servers.json with hive-tools server config
- Add credential file patterns to .gitignore

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 13:50:29 -04:00
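A minimal sketch of the lifecycle contract this commit restores; `AgentRuntime` below is a stand-in stub, not the real class:

```python
class AgentRuntime:  # stand-in stub, not the real class
    async def start(self) -> None: ...
    async def stop(self) -> None: ...


class SDRAgent:
    def __init__(self) -> None:
        self._agent_runtime: AgentRuntime | None = None

    async def start(self) -> None:
        self._agent_runtime = AgentRuntime()
        # The bug: the runtime was created but never started, so failures
        # only surfaced later, silently. Awaiting start() here fixes that.
        await self._agent_runtime.start()

    async def stop(self) -> None:
        if self._agent_runtime is not None:
            await self._agent_runtime.stop()
            self._agent_runtime = None  # reset for a clean restart
```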
Schlaflied 756a815bf0 feat(templates): add SDR Agent sample template 2026-03-17 13:50:05 -04:00
mma2027 23a7b080eb test: add comprehensive test suite for safe_eval (#4015)
* test: add comprehensive test suite for safe_eval sandboxed evaluator

Adds 113 tests across 14 test classes covering the full surface area of
the safe_eval expression evaluator used by edge conditions:

- Literals, data structures, arithmetic, unary/binary/boolean operators
- Short-circuit semantics for `and`/`or` (including guard patterns)
- Ternary expressions, variable lookup, subscript/attribute access
- Whitelisted function and method calls
- Security boundaries (private attrs, disallowed AST nodes, blocked builtins)
- Real-world EdgeSpec.condition_expr patterns from graph executor usage

* style: fix import sort order

---------

Co-authored-by: mma2027 <mma2027@users.noreply.github.com>
Co-authored-by: hundao <alchemy_wimp@hotmail.com>
2026-03-18 01:01:31 +08:00
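For flavor, a minimal AST-whitelist evaluator of the kind such a suite exercises; the project's `safe_eval` is more complete, and this sketch supports only a few node types:

```python
import ast
import operator

# operator table for the node types this sketch supports
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Gt: operator.gt, ast.Lt: operator.lt}


def safe_eval(expr: str, variables: dict):
    """Evaluate a restricted expression; anything off the whitelist raises."""

    def ev(node):
        if isinstance(node, ast.Expression):
            return ev(node.body)
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            return variables[node.id]
        if isinstance(node, ast.BoolOp):
            # genuine short-circuit: stop evaluating once the result is known
            result = None
            for value in node.values:
                result = ev(value)
                if isinstance(node.op, ast.And) and not result:
                    return result
                if isinstance(node.op, ast.Or) and result:
                    return result
            return result
        if isinstance(node, ast.Compare) and len(node.ops) == 1:
            return _OPS[type(node.ops[0])](ev(node.left), ev(node.comparators[0]))
        if isinstance(node, ast.BinOp):
            return _OPS[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.Subscript):
            return ev(node.value)[ev(node.slice)]
        raise ValueError(f"disallowed node: {type(node).__name__}")

    return ev(ast.parse(expr, mode="eval"))


# the guard pattern from EdgeSpec-style conditions: `x` protects `x[0]`
assert safe_eval("x and x[0] > 1", {"x": [5]}) is True
assert safe_eval("x and x[0] > 1", {"x": []}) == []
```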
mma2027 bf39bcdec9 fixed race condition deadlock, missing short-circuit eval, unhandled format exceptions (#4012) 2026-03-18 00:36:54 +08:00
Richard Tang 0276632491 Merge branch 'feat/graph-improvements' 2026-03-17 07:34:10 -07:00
RichardTang-Aden ae2993d0d1 Merge pull request #6528 from Antiarin/feat/trigger-nodes-in-draft-graph
Restore trigger nodes in the new flowchart
2026-03-16 20:54:36 -07:00
RichardTang-Aden d14d71f760 Merge pull request #6549 from aden-hive/staging
release 0.7.2
2026-03-16 20:44:47 -07:00
Antiarin 738641d35f fix: correct trigger target, label, and SSE event data
- Add name and entry_node to all trigger SSE events (TRIGGER_AVAILABLE,
  TRIGGER_ACTIVATED, TRIGGER_DEACTIVATED) so frontend gets correct data
  immediately instead of guessing
- Use ep.entry_node from backend in polling instead of guessing first
  non-trigger node
- Compute cronToLabel from trigger config during polling so pill labels
  show human-readable schedule
- Fix AsyncMock for event_bus.publish in tests
2026-03-17 09:07:10 +05:30
Antiarin 22f5534f08 fix: ensure Queen calls remove_trigger when user asks to remove scheduler
Added explicit prompt guidance requiring the Queen to call the
remove_trigger tool instead of just saying "it's removed."
2026-03-17 09:07:10 +05:30
Antiarin b79e7eca73 feat: live update trigger pill and detail panel on save
- Handle trigger_updated SSE event to update graph node label and
  config in real time when cron or task is saved
- Use cronToLabel for human-readable schedule display in detail panel
- Add "Saved" button feedback for Save Cron and Save Task (2s toast)
- Update trigger pill label to reflect new schedule on cron save
2026-03-17 09:07:10 +05:30
Antiarin 28250dc45e feat: support cron editing via trigger update API
- Extend PATCH /triggers/{id} to accept trigger_config with cron
  validation via croniter and active timer restart
- Add TRIGGER_UPDATED SSE event so frontend updates in real time
- Update frontend API client to use updateTrigger with config support
- Add tests for task update, cron restart, and invalid cron rejection
2026-03-17 09:07:10 +05:30
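The validation step is the part the commit pins down (`croniter`); a hedged sketch, with the surrounding endpoint shape assumed:

```python
from croniter import croniter


def apply_cron_update(trigger_config: dict, new_cron: str) -> dict:
    """Reject invalid expressions before touching the stored config."""
    if not croniter.is_valid(new_cron):
        raise ValueError(f"invalid cron expression: {new_cron!r}")
    return {**trigger_config, "cron": new_cron}
```

Per the commit, a successful update also restarts the active timer and emits a TRIGGER_UPDATED SSE event so the frontend refreshes in real time.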
Antiarin fe5df6a87a feat: restore trigger node rendering in DraftGraph
Trigger nodes (scheduler, webhook, etc.) stopped appearing after the
v0.7.0 refactor because DraftGraph had no trigger awareness.

- Extract shared utilities (cssVar, truncateLabel, trigger colors/icons,
  useTriggerColors, cronToLabel) into lib/graphUtils.ts
- Render trigger pills above the draft flowchart with pill shape, icons,
  countdown timers, active/inactive status, and click handling
- Draw dashed edges from trigger pills to the correct draft node using
  flowchartMap lookup
- Name all trigger layout constants, fix countdown text color bug
- Include trigger pill extent in SVG viewBox width

Closes #6344
2026-03-17 09:07:10 +05:30
levxn 88253883a3 tier 3 resource loading 2026-03-17 03:30:58 +05:30
Sundaram Kumar Jha ff7b5c7e27 fix: prepend ~/.local/bin to PATH so uv is found in Git Bash on Windows 2026-03-17 01:28:25 +05:30
levxn 6ed6e5b286 lint fixes 2026-03-17 00:32:14 +05:30
Vasu Bansal 30bb0ad5d8 style: format MCP connection manager 2026-03-16 23:46:44 +05:30
Vasu Bansal cb0845f5ba fix: wrap MCP manager cleanup condition 2026-03-16 23:41:36 +05:30
Levin ce2525b59c Merge branch 'aden-hive:main' into skills/trust-gating 2026-03-16 23:39:27 +05:30
levxn 1f77ec3831 fix bug introduced by the executor.py change; AS-13 together with upstream's AS-1 through AS-5 2026-03-16 23:38:45 +05:30
Vasu Bansal 6ab5aa8004 style: format mcp client
Apply ruff formatting to satisfy CI on the MCP transport changes.
2026-03-16 23:19:49 +05:30
Vasu Bansal 4449cd8ee8 feat: add shared MCP connection manager 2026-03-16 23:10:26 +05:30
Vasu Bansal 8b60c03a0a feat: add unix and sse MCP transports
Implements unix socket and SSE MCP transports, adds reconnect-once retry for unix/SSE, and adds focused unit coverage.
2026-03-16 23:03:44 +05:30
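A sketch of a reconnect-once retry, generic over a transport; the real classes in the PR are richer:

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")


async def call_with_reconnect(
    op: Callable[[], Awaitable[T]],
    reconnect: Callable[[], Awaitable[None]],
) -> T:
    """Run op(); on a connection failure, reconnect and retry exactly once."""
    try:
        return await op()
    except (ConnectionError, asyncio.IncompleteReadError):
        await reconnect()
        return await op()  # a second failure propagates to the caller
```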
Levin 0e98023e40 Merge branch 'aden-hive:main' into skills/trust-gating 2026-03-16 22:23:57 +05:30
Sundaram Kumar Jha 22bb07f00e chore: resolve merge conflict 2026-03-16 19:59:57 +05:30
Sundaram Kumar Jha 660f883197 style(core): apply ruff formatting to satisfy CI lint 2026-03-16 19:57:21 +05:30
Sundaram Kumar Jha 988de80b66 Merge branch 'main' into feature/openrouter-api-key-support 2026-03-16 19:51:04 +05:30
Sundaram Kumar Jha dc6aa226ee feat(openrouter): validate model readiness and harden tool-call handling
- add OpenRouter chat completion validation to key checks for quickstart flows

- improve OpenRouter compat parsing to convert plain textual tool calls into real tool events

- prevent tool-call text from leaking into assistant responses

- add regression tests for OpenRouter key checks and LiteLLM tool compat parsing
2026-03-16 19:39:11 +05:30
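A hedged sketch of stripping textual tool calls out of the assistant text, so the markup never leaks into the reply; the `<tool_call>` tag and JSON shape here are assumptions, and the non-greedy regex only handles flat, single-object payloads:

```python
import json
import re

_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def split_tool_calls(text: str) -> tuple[str, list[dict]]:
    """Return (clean_text, tool_calls); call markup never reaches the user."""
    calls = [json.loads(payload) for payload in _TOOL_CALL_RE.findall(text)]
    clean = _TOOL_CALL_RE.sub("", text).strip()
    return clean, calls
```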
levxn 48a54b4ee2 implement AS-13, trust gating for project-level skills 2026-03-16 17:45:33 +05:30
Sundaram Kumar Jha a7b6b080ab chore(lockfiles): refresh generated lockfiles
- update frontend package-lock metadata after frontend validation
- refresh uv.lock editable package version for the current workspace state
2026-03-14 20:50:51 +05:30
Sundaram Kumar Jha 9202cbd4d4 fix(openrouter): stabilize quickstart and tool execution
- add cross-platform OpenRouter quickstart setup, config fallbacks, and key validation
- harden LiteLLM/OpenRouter tool execution, duplicate question handling, and worker loading UX
- add backend and frontend regression coverage for OpenRouter flows
2026-03-14 20:48:58 +05:30
135 changed files with 16913 additions and 2048 deletions
+14 -4
@@ -2,14 +2,22 @@ name: Bounty completed
 description: Awards points and notifies Discord when a bounty PR is merged
 on:
-  pull_request:
+  pull_request_target:
     types: [closed]
+  workflow_dispatch:
+    inputs:
+      pr_number:
+        description: "PR number to process (for missed bounties)"
+        required: true
+        type: number
 jobs:
   bounty-notify:
     if: >
-      github.event.pull_request.merged == true &&
-      contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:')
+      github.event_name == 'workflow_dispatch' ||
+      (github.event.pull_request.merged == true &&
+      contains(join(github.event.pull_request.labels.*.name, ','), 'bounty:'))
     runs-on: ubuntu-latest
     timeout-minutes: 5
     permissions:
@@ -32,6 +40,8 @@ jobs:
       GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
       GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
       DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
+      BOT_API_URL: ${{ secrets.BOT_API_URL }}
+      BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
       LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
       LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
-      PR_NUMBER: ${{ github.event.pull_request.number }}
+      PR_NUMBER: ${{ inputs.pr_number || github.event.pull_request.number }}
-126
@@ -1,126 +0,0 @@
name: Link Discord account
description: Auto-creates a PR to add contributor to contributors.yml when a link-discord issue is opened

on:
  issues:
    types: [opened]

jobs:
  link-discord:
    if: contains(github.event.issue.labels.*.name, 'link-discord')
    runs-on: ubuntu-latest
    timeout-minutes: 2
    permissions:
      contents: write
      issues: write
      pull-requests: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Parse issue and update contributors.yml
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const issue = context.payload.issue;
            const githubUsername = issue.user.login;

            // Parse the issue body for form fields
            const body = issue.body || '';

            // Extract Discord ID — look for the numeric value after the "Discord User ID" heading
            const discordMatch = body.match(/### Discord User ID\s*\n\s*(\d{17,20})/);
            if (!discordMatch) {
              await github.rest.issues.createComment({
                ...context.repo,
                issue_number: issue.number,
                body: `Could not find a valid Discord ID in the issue body. Please make sure you entered a numeric ID (17-20 digits), not a username.\n\nExample: \`123456789012345678\``
              });
              await github.rest.issues.update({
                ...context.repo,
                issue_number: issue.number,
                state: 'closed',
                state_reason: 'not_planned'
              });
              return;
            }
            const discordId = discordMatch[1];

            // Extract display name (optional)
            const nameMatch = body.match(/### Display Name \(optional\)\s*\n\s*(.+)/);
            const displayName = nameMatch ? nameMatch[1].trim() : '';

            // Check if user already exists
            const yml = fs.readFileSync('contributors.yml', 'utf-8');
            if (yml.includes(`github: ${githubUsername}`)) {
              await github.rest.issues.createComment({
                ...context.repo,
                issue_number: issue.number,
                body: `@${githubUsername} is already in \`contributors.yml\`. If you need to update your Discord ID, please edit the file directly via PR.`
              });
              await github.rest.issues.update({
                ...context.repo,
                issue_number: issue.number,
                state: 'closed',
                state_reason: 'completed'
              });
              return;
            }

            // Append entry to contributors.yml
            let entry = `  - github: ${githubUsername}\n    discord: "${discordId}"`;
            if (displayName && displayName !== '_No response_') {
              entry += `\n    name: ${displayName}`;
            }
            entry += '\n';
            const updated = yml.trimEnd() + '\n' + entry;
            fs.writeFileSync('contributors.yml', updated);

            // Set outputs for commit step
            core.exportVariable('GITHUB_USERNAME', githubUsername);
            core.exportVariable('DISCORD_ID', discordId);
            core.exportVariable('ISSUE_NUMBER', issue.number.toString());

      - name: Create PR
        run: |
          # Check if there are changes
          if git diff --quiet contributors.yml; then
            echo "No changes to contributors.yml"
            exit 0
          fi
          BRANCH="docs/link-discord-${GITHUB_USERNAME}"
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git checkout -b "$BRANCH"
          git add contributors.yml
          git commit -m "docs: link @${GITHUB_USERNAME} to Discord"
          git push origin "$BRANCH"
          gh pr create \
            --title "docs: link @${GITHUB_USERNAME} to Discord" \
            --body "Adds @${GITHUB_USERNAME} (Discord \`${DISCORD_ID}\`) to \`contributors.yml\` for bounty XP tracking.

          Closes #${ISSUE_NUMBER}" \
            --base main \
            --head "$BRANCH" \
            --label "link-discord"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Notify on issue
        uses: actions/github-script@v7
        with:
          script: |
            const username = process.env.GITHUB_USERNAME;
            const issueNumber = parseInt(process.env.ISSUE_NUMBER);
            await github.rest.issues.createComment({
              ...context.repo,
              issue_number: issueNumber,
              body: `A PR has been created to link your account. A maintainer will merge it shortly — once merged, you'll receive XP and Discord pings when your bounty PRs are merged.`
            });
+2
@@ -35,6 +35,8 @@ jobs:
       GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
       GITHUB_REPOSITORY_NAME: ${{ github.event.repository.name }}
       DISCORD_WEBHOOK_URL: ${{ secrets.DISCORD_BOUNTY_WEBHOOK_URL }}
+      BOT_API_URL: ${{ secrets.BOT_API_URL }}
+      BOT_API_KEY: ${{ secrets.BOT_API_KEY }}
       LURKR_API_KEY: ${{ secrets.LURKR_API_KEY }}
       LURKR_GUILD_ID: ${{ secrets.LURKR_GUILD_ID }}
       SINCE_DATE: ${{ github.event.inputs.since_date || '' }}
-1
@@ -68,7 +68,6 @@ temp/
 exports/*
 .claude/settings.local.json
 .claude/skills/ship-it/
-.venv
+9 -2
@@ -1,4 +1,11 @@
.PHONY: lint format check test install-hooks help frontend-install frontend-dev frontend-build
.PHONY: lint format check test test-tools test-live test-all install-hooks help frontend-install frontend-dev frontend-build
# ── Ensure uv is findable in Git Bash on Windows ──────────────────────────────
# uv installs to ~/.local/bin on Windows/Linux/macOS. Git Bash may not include
# this in PATH by default, so we prepend it here.
export PATH := $(HOME)/.local/bin:$(PATH)
# ── Targets ───────────────────────────────────────────────────────────────────
help: ## Show this help
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \
@@ -46,4 +53,4 @@ frontend-dev: ## Start frontend dev server
cd core/frontend && npm run dev
frontend-build: ## Build frontend for production
cd core/frontend && npm run build
cd core/frontend && npm run build
+3 -1
@@ -41,7 +41,9 @@ Generate a swarm of worker agents with a coding agent(queen) that control them.
 Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
 
-[![Hive Demo](https://img.youtube.com/vi/XDOG9fOaLjU/maxresdefault.jpg)](https://www.youtube.com/watch?v=XDOG9fOaLjU)
+https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
 
 ## Who Is Hive For?
-31
@@ -1,31 +0,0 @@
perf: reduce subprocess spawning in quickstart scripts (#4427)
## Problem
Windows process creation (CreateProcess) is 10-100x slower than Linux fork/exec.
The quickstart scripts were spawning 4+ separate `uv run python -c "import X"`
processes to verify imports, adding ~600ms overhead on Windows.
## Solution
Consolidated all import checks into a single batch script that checks multiple
modules in one subprocess call, reducing spawn overhead by ~75%.
## Changes
- **New**: `scripts/check_requirements.py` - Batched import checker
- **New**: `scripts/test_check_requirements.py` - Test suite
- **New**: `scripts/benchmark_quickstart.ps1` - Performance benchmark tool
- **Modified**: `quickstart.ps1` - Updated import verification (2 sections)
- **Modified**: `quickstart.sh` - Updated import verification
## Performance Impact
**Benchmark results on Windows:**
- Before: ~19.8 seconds for import checks
- After: ~4.9 seconds for import checks
- **Improvement: 14.9 seconds saved (75.2% faster)**
## Testing
- ✅ All functional tests pass (`scripts/test_check_requirements.py`)
- ✅ Quickstart scripts work correctly on Windows
- ✅ Error handling verified (invalid imports reported correctly)
- ✅ Performance benchmark confirms 75%+ improvement
Fixes #4427
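A minimal sketch of a batched import checker in the spirit of the `scripts/check_requirements.py` described above: one subprocess verifies every module instead of spawning `python -c "import X"` per module (the module list here is illustrative):

```python
import importlib
import sys

MODULES = ["httpx", "pydantic", "litellm"]  # illustrative, not the real list

failed = []
for name in MODULES:
    try:
        importlib.import_module(name)
    except ImportError as exc:
        failed.append(f"{name}: {exc}")

if failed:
    print("Missing requirements:\n  " + "\n  ".join(failed))
    sys.exit(1)
print(f"All {len(MODULES)} imports OK")
```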
-27
@@ -1,27 +0,0 @@
# Identity mapping: GitHub username -> Discord ID
#
# This file links GitHub accounts to Discord accounts for the
# Integration Bounty Program. When a bounty PR is merged, the
# GitHub Action uses this file to ping the contributor on Discord.
#
# HOW TO ADD YOURSELF:
# Open a "Link Discord Account" issue:
# https://github.com/aden-hive/hive/issues/new?template=link-discord.yml
# A GitHub Action will automatically add your entry here.
#
# To find your Discord ID:
# 1. Open Discord Settings > Advanced > Enable Developer Mode
# 2. Right-click your name > Copy User ID
#
# Format:
# - github: your-github-username
# discord: "your-discord-id" # quotes required (it's a number)
# name: Your Display Name # optional
contributors:
  # - github: example-user
  #   discord: "123456789012345678"
  #   name: Example User
  - github: TimothyZhang7
    discord: "408460790061072384"
    name: Timothy@Aden
+583
@@ -0,0 +1,583 @@
#!/usr/bin/env python3
"""Antigravity authentication CLI.

Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
Credentials are stored in ~/.hive/antigravity-accounts.json.

Usage:
    python -m antigravity_auth auth account add
    python -m antigravity_auth auth account list
    python -m antigravity_auth auth account remove <email>
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import secrets
import socket
import sys
import time
import urllib.parse
import urllib.request
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# OAuth endpoints
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"

# Scopes for Antigravity/Cloud Code Assist
_OAUTH_SCOPES = [
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
]

# Credentials file path in ~/.hive/
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"

# Default project ID
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_DEFAULT_REDIRECT_PORT = 51121

# OAuth credentials fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_CREDENTIALS_URL = (
    "https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)

# Cached credentials fetched from public source
_cached_client_id: str | None = None
_cached_client_secret: str | None = None


def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
    """Fetch OAuth client ID and secret from the public npm package source on GitHub."""
    global _cached_client_id, _cached_client_secret
    if _cached_client_id and _cached_client_secret:
        return _cached_client_id, _cached_client_secret
    try:
        req = urllib.request.Request(
            _CREDENTIALS_URL, headers={"User-Agent": "Hive-Antigravity-Auth/1.0"}
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            content = resp.read().decode("utf-8")
        import re

        id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
        secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
        if id_match:
            _cached_client_id = id_match.group(1)
        if secret_match:
            _cached_client_secret = secret_match.group(1)
        return _cached_client_id, _cached_client_secret
    except Exception as e:
        logger.debug(f"Failed to fetch credentials from public source: {e}")
        return None, None


def get_client_id() -> str:
    """Get OAuth client ID from env, config, or public source."""
    env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
    if env_id:
        return env_id
    # Try hive config
    hive_cfg = Path.home() / ".hive" / "configuration.json"
    if hive_cfg.exists():
        try:
            with open(hive_cfg) as f:
                cfg = json.load(f)
            cfg_id = cfg.get("llm", {}).get("antigravity_client_id")
            if cfg_id:
                return cfg_id
        except Exception:
            pass
    # Fetch from public source
    client_id, _ = _fetch_credentials_from_public_source()
    if client_id:
        return client_id
    raise RuntimeError("Could not obtain Antigravity OAuth client ID")


def get_client_secret() -> str | None:
    """Get OAuth client secret from env, config, or public source."""
    secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
    if secret:
        return secret
    # Try to read from hive config
    hive_cfg = Path.home() / ".hive" / "configuration.json"
    if hive_cfg.exists():
        try:
            with open(hive_cfg) as f:
                cfg = json.load(f)
            secret = cfg.get("llm", {}).get("antigravity_client_secret")
            if secret:
                return secret
        except Exception:
            pass
    # Fetch from public source (npm package on GitHub)
    _, secret = _fetch_credentials_from_public_source()
    return secret


def find_free_port() -> int:
    """Find an available local port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        return s.getsockname()[1]


class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """Handle OAuth callback from browser."""

    auth_code: str | None = None
    state: str | None = None
    error: str | None = None

    def log_message(self, format: str, *args: Any) -> None:
        pass  # Suppress default logging

    def do_GET(self) -> None:
        parsed = urllib.parse.urlparse(self.path)
        if parsed.path == "/oauth-callback":
            query = urllib.parse.parse_qs(parsed.query)
            if "error" in query:
                # Store on the class (not the instance) so wait_for_callback,
                # which polls the class attributes, can observe the failure.
                OAuthCallbackHandler.error = query["error"][0]
                self._send_response("Authentication failed. You can close this window.")
                return
            if "code" in query and "state" in query:
                OAuthCallbackHandler.auth_code = query["code"][0]
                OAuthCallbackHandler.state = query["state"][0]
                self._send_response(
                    "Authentication successful! You can close this window "
                    "and return to the terminal."
                )
                return
        self._send_response("Waiting for authentication...")

    def _send_response(self, message: str) -> None:
        self.send_response(200)
        self.send_header("Content-Type", "text/html")
        self.end_headers()
        html = f"""<!DOCTYPE html>
<html>
<head><title>Antigravity Auth</title></head>
<body style="font-family: system-ui; display: flex; align-items: center;
justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
color: #eee;">
<div style="text-align: center;">
<h2>{message}</h2>
</div>
</body>
</html>"""
        self.wfile.write(html.encode())


def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
    """Start local server and wait for OAuth callback."""
    server = HTTPServer(("localhost", port), OAuthCallbackHandler)
    server.timeout = 1
    start = time.time()
    while time.time() - start < timeout:
        # Also stop on error (set as a class attribute by the handler) so a
        # denied consent screen doesn't spin until the timeout.
        if OAuthCallbackHandler.auth_code or OAuthCallbackHandler.error:
            return (
                OAuthCallbackHandler.auth_code,
                OAuthCallbackHandler.state,
                OAuthCallbackHandler.error,
            )
        server.handle_request()
    return None, None, "timeout"


def exchange_code_for_tokens(
    code: str, redirect_uri: str, client_id: str, client_secret: str | None
) -> dict[str, Any] | None:
    """Exchange authorization code for tokens."""
    data = {
        "code": code,
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "grant_type": "authorization_code",
    }
    if client_secret:
        data["client_secret"] = client_secret
    body = urllib.parse.urlencode(data).encode()
    req = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read())
    except Exception as e:
        logger.error(f"Token exchange failed: {e}")
        return None


def get_user_email(access_token: str) -> str | None:
    """Get user email from Google API."""
    req = urllib.request.Request(
        "https://www.googleapis.com/oauth2/v2/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
        return data.get("email")
    except Exception:
        return None


def load_accounts() -> dict[str, Any]:
    """Load existing accounts from file."""
    if not _ACCOUNTS_FILE.exists():
        return {"schemaVersion": 4, "accounts": []}
    try:
        with open(_ACCOUNTS_FILE) as f:
            return json.load(f)
    except Exception:
        return {"schemaVersion": 4, "accounts": []}


def save_accounts(data: dict[str, Any]) -> None:
    """Save accounts to file."""
    _ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(_ACCOUNTS_FILE, "w") as f:
        json.dump(data, f, indent=2)
    logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")


def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
    """Test if credentials work by making a simple API call to Antigravity.

    Returns True if credentials are valid, False otherwise.
    """
    endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
    body = {
        "project": project_id,
        "model": "gemini-3-flash",
        "request": {
            "contents": [{"role": "user", "parts": [{"text": "hi"}]}],
            "generationConfig": {"maxOutputTokens": 10},
        },
        "requestType": "agent",
        "userAgent": "antigravity",
        "requestId": "validation-test",
    }
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
        ),
        "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    }
    try:
        req = urllib.request.Request(
            f"{endpoint}/v1internal:generateContent",
            data=json.dumps(body).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            json.loads(resp.read())
        return True
    except Exception:
        return False


def refresh_access_token(
    refresh_token: str, client_id: str, client_secret: str | None
) -> dict | None:
    """Refresh the access token using the refresh token."""
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        data["client_secret"] = client_secret
    body = urllib.parse.urlencode(data).encode()
    req = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read())
    except Exception as e:
        logger.debug(f"Token refresh failed: {e}")
        return None


def cmd_account_add(args: argparse.Namespace) -> int:
    """Add a new Antigravity account via OAuth2.

    First checks if valid credentials already exist. If so, validates them
    and skips OAuth if they work. Otherwise, proceeds with OAuth flow.
    """
    client_id = get_client_id()
    client_secret = get_client_secret()

    # Check if credentials already exist
    accounts_data = load_accounts()
    accounts = accounts_data.get("accounts", [])
    if accounts:
        account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
        access_token = account.get("access")
        refresh_token_str = account.get("refresh", "")
        refresh_token = refresh_token_str.split("|")[0] if refresh_token_str else None
        project_id = (
            refresh_token_str.split("|")[1] if "|" in refresh_token_str else _DEFAULT_PROJECT_ID
        )
        email = account.get("email", "unknown")
        expires_ms = account.get("expires", 0)
        expires_at = expires_ms / 1000.0 if expires_ms else 0.0

        # Check if token is expired or near expiry
        if access_token and expires_at and time.time() < expires_at - 60:
            # Token still valid, test it
            logger.info(f"Found existing credentials for: {email}")
            logger.info("Validating existing credentials...")
            if validate_credentials(access_token, project_id):
                logger.info("✓ Credentials valid! Skipping OAuth.")
                return 0
            else:
                logger.info("Credentials failed validation, refreshing...")
        elif refresh_token:
            logger.info(f"Found expired credentials for: {email}")
            logger.info("Attempting token refresh...")
            tokens = refresh_access_token(refresh_token, client_id, client_secret)
            if tokens:
                new_access = tokens.get("access_token")
                expires_in = tokens.get("expires_in", 3600)
                if new_access:
                    # Update the account
                    account["access"] = new_access
                    account["expires"] = int((time.time() + expires_in) * 1000)
                    accounts_data["last_refresh"] = time.strftime(
                        "%Y-%m-%dT%H:%M:%SZ", time.gmtime()
                    )
                    save_accounts(accounts_data)
                    # Validate the refreshed token
                    logger.info("Validating refreshed credentials...")
                    if validate_credentials(new_access, project_id):
                        logger.info("✓ Credentials refreshed and validated!")
                        return 0
                    else:
                        logger.info("Refreshed token failed validation, proceeding with OAuth...")
            else:
                logger.info("Token refresh failed, proceeding with OAuth...")

    # No valid credentials, proceed with OAuth
    if not client_secret:
        logger.warning(
            "No client secret configured. Token refresh may fail.\n"
            "Set ANTIGRAVITY_CLIENT_SECRET env var or add "
            "'antigravity_client_secret' to ~/.hive/configuration.json"
        )

    # Use fixed port and path matching Google's expected OAuth redirect URI
    port = _DEFAULT_REDIRECT_PORT
    redirect_uri = f"http://localhost:{port}/oauth-callback"

    # Generate state for CSRF protection
    state = secrets.token_urlsafe(16)

    # Build authorization URL
    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": " ".join(_OAUTH_SCOPES),
        "state": state,
        "access_type": "offline",
        "prompt": "consent",
    }
    auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"

    logger.info("Opening browser for authentication...")
    logger.info(f"If the browser doesn't open, visit: {auth_url}\n")

    # Open browser
    webbrowser.open(auth_url)

    # Wait for callback
    logger.info(f"Listening for callback on port {port}...")
    code, received_state, error = wait_for_callback(port)

    if error:
        logger.error(f"Authentication failed: {error}")
        return 1
    if not code:
        logger.error("No authorization code received")
        return 1
    if received_state != state:
        logger.error("State mismatch - possible CSRF attack")
        return 1

    # Exchange code for tokens
    logger.info("Exchanging authorization code for tokens...")
    tokens = exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)
    if not tokens:
        return 1

    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    expires_in = tokens.get("expires_in", 3600)
    if not access_token:
        logger.error("No access token in response")
        return 1

    # Get user email
    email = get_user_email(access_token)
    if email:
        logger.info(f"Authenticated as: {email}")

    # Load existing accounts and add/update
    accounts_data = load_accounts()
    accounts = accounts_data.get("accounts", [])

    # Build new account entry (V4 schema)
    expires_ms = int((time.time() + expires_in) * 1000)
    refresh_entry = f"{refresh_token}|{_DEFAULT_PROJECT_ID}"
    new_account = {
        "access": access_token,
        "refresh": refresh_entry,
        "expires": expires_ms,
        "email": email,
        "enabled": True,
    }

    # Update existing account or add new one
    existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
    if existing_idx is not None:
        accounts[existing_idx] = new_account
        logger.info(f"Updated existing account: {email}")
    else:
        accounts.append(new_account)
        logger.info(f"Added new account: {email}")

    accounts_data["accounts"] = accounts
    accounts_data["schemaVersion"] = 4
    accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    save_accounts(accounts_data)

    logger.info("\n✓ Authentication complete!")
    return 0


def cmd_account_list(args: argparse.Namespace) -> int:
    """List all stored accounts."""
    data = load_accounts()
    accounts = data.get("accounts", [])
    if not accounts:
        logger.info("No accounts configured.")
        logger.info("Run 'antigravity auth account add' to add one.")
        return 0
    logger.info("Configured accounts:\n")
    for i, account in enumerate(accounts, 1):
        email = account.get("email", "unknown")
        enabled = "enabled" if account.get("enabled", True) else "disabled"
        logger.info(f"  {i}. {email} ({enabled})")
    return 0


def cmd_account_remove(args: argparse.Namespace) -> int:
    """Remove an account by email."""
    email = args.email
    data = load_accounts()
    accounts = data.get("accounts", [])
    original_len = len(accounts)
    accounts = [a for a in accounts if a.get("email") != email]
    if len(accounts) == original_len:
        logger.error(f"No account found with email: {email}")
        return 1
    data["accounts"] = accounts
    save_accounts(data)
    logger.info(f"Removed account: {email}")
    return 0


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Antigravity authentication CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # auth account add
    auth_parser = subparsers.add_parser("auth", help="Authentication commands")
    auth_subparsers = auth_parser.add_subparsers(dest="auth_command")
    account_parser = auth_subparsers.add_parser("account", help="Account management")
    account_subparsers = account_parser.add_subparsers(dest="account_command")

    add_parser = account_subparsers.add_parser("add", help="Add a new account via OAuth2")
    add_parser.set_defaults(func=cmd_account_add)

    list_parser = account_subparsers.add_parser("list", help="List configured accounts")
    list_parser.set_defaults(func=cmd_account_list)

    remove_parser = account_subparsers.add_parser("remove", help="Remove an account")
    remove_parser.add_argument("email", help="Email of account to remove")
    remove_parser.set_defaults(func=cmd_account_remove)

    args = parser.parse_args()
    if hasattr(args, "func"):
        return args.func(args)
    parser.print_help()
    return 0


if __name__ == "__main__":
    sys.exit(main())
+50 -19
@@ -23,25 +23,56 @@ class AgentEntry:
     last_active: str | None = None
 
-def _get_last_active(agent_name: str) -> str | None:
-    """Return the most recent updated_at timestamp across all sessions."""
-    sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
-    if not sessions_dir.exists():
-        return None
+def _get_last_active(agent_path: Path) -> str | None:
+    """Return the most recent updated_at timestamp across all sessions.
+
+    Checks both worker sessions (``~/.hive/agents/{name}/sessions/``) and
+    queen sessions (``~/.hive/queen/session/``) whose ``meta.json`` references
+    the same *agent_path*.
+    """
+    from datetime import datetime
+
+    agent_name = agent_path.name
     latest: str | None = None
-    for session_dir in sessions_dir.iterdir():
-        if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
-            continue
-        state_file = session_dir / "state.json"
-        if not state_file.exists():
-            continue
-        try:
-            data = json.loads(state_file.read_text(encoding="utf-8"))
-            ts = data.get("timestamps", {}).get("updated_at")
-            if ts and (latest is None or ts > latest):
-                latest = ts
-        except Exception:
-            continue
+
+    # 1. Worker sessions
+    sessions_dir = Path.home() / ".hive" / "agents" / agent_name / "sessions"
+    if sessions_dir.exists():
+        for session_dir in sessions_dir.iterdir():
+            if not session_dir.is_dir() or not session_dir.name.startswith("session_"):
+                continue
+            state_file = session_dir / "state.json"
+            if not state_file.exists():
+                continue
+            try:
+                data = json.loads(state_file.read_text(encoding="utf-8"))
+                ts = data.get("timestamps", {}).get("updated_at")
+                if ts and (latest is None or ts > latest):
+                    latest = ts
+            except Exception:
+                continue
+
+    # 2. Queen sessions
+    queen_sessions_dir = Path.home() / ".hive" / "queen" / "session"
+    if queen_sessions_dir.exists():
+        resolved = agent_path.resolve()
+        for d in queen_sessions_dir.iterdir():
+            if not d.is_dir():
+                continue
+            meta_file = d / "meta.json"
+            if not meta_file.exists():
+                continue
+            try:
+                meta = json.loads(meta_file.read_text(encoding="utf-8"))
+                stored = meta.get("agent_path")
+                if not stored or Path(stored).resolve() != resolved:
+                    continue
+                ts = datetime.fromtimestamp(d.stat().st_mtime).isoformat()
+                if latest is None or ts > latest:
+                    latest = ts
+            except Exception:
+                continue
+
     return latest
@@ -169,7 +200,7 @@ def discover_agents() -> dict[str, list[AgentEntry]]:
                 node_count=node_count,
                 tool_count=tool_count,
                 tags=tags,
-                last_active=_get_last_active(path.name),
+                last_active=_get_last_active(path),
             )
         )
     if entries:
@@ -702,6 +702,15 @@ stop_worker() to return to STAGING phase.
 _queen_behavior_always = """
 # Behavior
 
+## Images attached by the user
+Users can attach images directly to their chat messages. When you see an \
+image in the conversation, analyze it using your native vision capability; \
+do NOT say you cannot see images or that you lack access to files. The image \
+is embedded in the message; no tool call is needed to view it. Describe what \
+you see, answer questions about it, and use the visual content to inform your \
+response just as you would text.
+
 ## CRITICAL RULE — ask_user / ask_user_multiple
 Every response that ends with a question, a prompt, or expects user \
@@ -1144,6 +1153,8 @@ Batch your response — do not call run_agent_with_input() once per trigger.
   config since last run), skip it and inform the user.
 - Never disable a trigger without telling the user. Use remove_trigger() only \
   when explicitly asked or when the trigger is clearly obsolete.
+- When the user asks to remove or disable a trigger, you MUST call remove_trigger(trigger_id). \
+  Never just say "it's removed" without actually calling the tool.
 """
 
 # -- Backward-compatible composed versions (used by queen_node.system_prompt default) --
@@ -150,7 +150,7 @@ Call all three subagents in a single response to run them in parallel:
 ## GCU Anti-Patterns
 
-- Using `browser_screenshot` to read text (use `browser_snapshot`)
+- Using `browser_screenshot` to read text (use `browser_snapshot` instead; screenshots are for visual context only)
 - Re-navigating after scrolling (resets scroll position)
 - Attempting login on auth walls
 - Forgetting `target_id` in multi-tab scenarios
+286
@@ -0,0 +1,286 @@
"""Worker per-run digest (run diary).
Storage layout:
~/.hive/agents/{agent_name}/runs/{run_id}/digest.md
Each completed or failed worker run gets one digest file. The queen reads
these via get_worker_status(focus='diary') before digging into live runtime
logs the diary is a cheap, persistent record that survives across sessions.
"""
from __future__ import annotations
import logging
import traceback
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from framework.runtime.event_bus import AgentEvent, EventBus
logger = logging.getLogger(__name__)
_DIGEST_SYSTEM = """\
You maintain run digests for a worker agent.
A run digest is a concise, factual record of a single task execution.
Write 3-6 sentences covering:
- What the worker was asked to do (the task/goal)
- What approach it took and what tools it used
- What the outcome was (success, partial, or failure and why if relevant)
- Any notable issues, retries, or escalations to the queen
Write in third person past tense. Be direct and specific.
Omit routine tool invocations unless the result matters.
Output only the digest prose no headings, no code fences.
"""
def _worker_runs_dir(agent_name: str) -> Path:
return Path.home() / ".hive" / "agents" / agent_name / "runs"
def digest_path(agent_name: str, run_id: str) -> Path:
return _worker_runs_dir(agent_name) / run_id / "digest.md"
def _collect_run_events(bus: EventBus, run_id: str, limit: int = 2000) -> list[AgentEvent]:
"""Collect all events belonging to *run_id* from the bus history.
Strategy: find the EXECUTION_STARTED event that carries ``run_id``,
extract its ``execution_id``, then query the bus by that execution_id.
This works because TOOL_CALL_*, EDGE_TRAVERSED, NODE_STALLED etc. carry
execution_id but not run_id.
Falls back to a full-scan run_id filter when EXECUTION_STARTED is not
found (e.g. bus was rotated).
"""
from framework.runtime.event_bus import EventType
# Pass 1: find execution_id via EXECUTION_STARTED with matching run_id
started = bus.get_history(event_type=EventType.EXECUTION_STARTED, limit=limit)
exec_id: str | None = None
for e in started:
if getattr(e, "run_id", None) == run_id and e.execution_id:
exec_id = e.execution_id
break
if exec_id:
return bus.get_history(execution_id=exec_id, limit=limit)
# Fallback: scan all events and match by run_id attribute
return [e for e in bus.get_history(limit=limit) if getattr(e, "run_id", None) == run_id]
def _build_run_context(
events: list[AgentEvent],
outcome_event: AgentEvent | None,
) -> str:
"""Assemble a plain-text run context string for the digest LLM call."""
from framework.runtime.event_bus import EventType
# Reverse so events are in chronological order
events_chron = list(reversed(events))
lines: list[str] = []
# Task input from EXECUTION_STARTED
started = [e for e in events_chron if e.type == EventType.EXECUTION_STARTED]
if started:
inp = started[0].data.get("input", {})
if inp:
lines.append(f"Task input: {str(inp)[:400]}")
# Duration (elapsed so far if no outcome yet)
ref_ts = outcome_event.timestamp if outcome_event else datetime.utcnow()
if started:
elapsed = (ref_ts - started[0].timestamp).total_seconds()
m, s = divmod(int(elapsed), 60)
lines.append(f"Duration so far: {m}m {s}s" if m else f"Duration so far: {s}s")
# Outcome
if outcome_event is None:
lines.append("Status: still running (mid-run snapshot)")
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
out = outcome_event.data.get("output", {})
out_str = f"Outcome: completed. Output: {str(out)[:300]}"
lines.append(out_str if out else "Outcome: completed.")
else:
err = outcome_event.data.get("error", "")
lines.append(f"Outcome: failed. Error: {str(err)[:300]}" if err else "Outcome: failed.")
# Node path (edge traversals)
edges = [e for e in events_chron if e.type == EventType.EDGE_TRAVERSED]
if edges:
parts = [
f"{e.data.get('source_node', '?')}->{e.data.get('target_node', '?')}"
for e in edges[-20:]
]
lines.append(f"Node path: {', '.join(parts)}")
# Tools used
tool_events = [e for e in events_chron if e.type == EventType.TOOL_CALL_COMPLETED]
if tool_events:
names = [e.data.get("tool_name", "?") for e in tool_events]
counts = Counter(names)
summary = ", ".join(f"{name}×{n}" if n > 1 else name for name, n in counts.most_common())
lines.append(f"Tools used: {summary}")
# Note any tool errors
errors = [e for e in tool_events if e.data.get("is_error")]
if errors:
err_names = Counter(e.data.get("tool_name", "?") for e in errors)
lines.append(f"Tool errors: {dict(err_names)}")
# Issues
issue_map = {
EventType.NODE_STALLED: "stall",
EventType.NODE_TOOL_DOOM_LOOP: "doom loop",
EventType.CONSTRAINT_VIOLATION: "constraint violation",
EventType.NODE_RETRY: "retry",
}
issue_parts: list[str] = []
for evt_type, label in issue_map.items():
n = sum(1 for e in events_chron if e.type == evt_type)
if n:
issue_parts.append(f"{n} {label}(s)")
if issue_parts:
lines.append(f"Issues: {', '.join(issue_parts)}")
# Escalations to queen
escalations = [e for e in events_chron if e.type == EventType.ESCALATION_REQUESTED]
if escalations:
lines.append(f"Escalations to queen: {len(escalations)}")
# Final LLM output snippet (last LLM_TEXT_DELTA snapshot)
text_events = [e for e in reversed(events_chron) if e.type == EventType.LLM_TEXT_DELTA]
if text_events:
snapshot = text_events[0].data.get("snapshot", "") or ""
if snapshot:
lines.append(f"Final LLM output: {snapshot[-400:].strip()}")
return "\n".join(lines)
async def consolidate_worker_run(
agent_name: str,
run_id: str,
outcome_event: AgentEvent | None,
bus: EventBus,
llm: Any,
) -> None:
"""Write (or overwrite) the digest for a worker run.
Called fire-and-forget either:
- After EXECUTION_COMPLETED / EXECUTION_FAILED (outcome_event set, final write)
- Periodically during a run on a cooldown timer (outcome_event=None, mid-run snapshot)
The digest file is always overwritten so each call produces the freshest view.
The final completion/failure call supersedes any mid-run snapshot.
Args:
agent_name: Worker agent directory name (determines storage path).
run_id: The run ID.
outcome_event: EXECUTION_COMPLETED or EXECUTION_FAILED event, or None for
a mid-run snapshot.
bus: The session EventBus (shared queen + worker).
llm: LLMProvider with an acomplete() method.
"""
try:
events = _collect_run_events(bus, run_id)
run_context = _build_run_context(events, outcome_event)
if not run_context:
logger.debug("worker_memory: no events for run %s, skipping digest", run_id)
return
is_final = outcome_event is not None
logger.info(
"worker_memory: generating %s digest for run %s ...",
"final" if is_final else "mid-run",
run_id,
)
from framework.agents.queen.config import default_config
resp = await llm.acomplete(
messages=[{"role": "user", "content": run_context}],
system=_DIGEST_SYSTEM,
max_tokens=min(default_config.max_tokens, 512),
)
digest_text = (resp.content or "").strip()
if not digest_text:
logger.warning("worker_memory: LLM returned empty digest for run %s", run_id)
return
path = digest_path(agent_name, run_id)
path.parent.mkdir(parents=True, exist_ok=True)
from framework.runtime.event_bus import EventType
ts = (outcome_event.timestamp if outcome_event else datetime.utcnow()).strftime(
"%Y-%m-%d %H:%M"
)
if outcome_event is None:
status = "running"
elif outcome_event.type == EventType.EXECUTION_COMPLETED:
status = "completed"
else:
status = "failed"
path.write_text(
f"# {run_id}\n\n**{ts}** | {status}\n\n{digest_text}\n",
encoding="utf-8",
)
logger.info(
"worker_memory: %s digest written for run %s (%d chars)",
status,
run_id,
len(digest_text),
)
except Exception:
tb = traceback.format_exc()
logger.exception("worker_memory: digest failed for run %s", run_id)
# Persist the error so it's findable without log access
error_path = _worker_runs_dir(agent_name) / run_id / "digest_error.txt"
try:
error_path.parent.mkdir(parents=True, exist_ok=True)
error_path.write_text(
f"run_id: {run_id}\ntime: {datetime.now().isoformat()}\n\n{tb}",
encoding="utf-8",
)
except Exception:
pass
def read_recent_digests(agent_name: str, max_runs: int = 5) -> list[tuple[str, str]]:
"""Return recent run digests as [(run_id, content), ...], newest first.
Args:
agent_name: Worker agent directory name.
max_runs: Maximum number of digests to return.
Returns:
List of (run_id, digest_content) tuples, ordered newest first.
"""
runs_dir = _worker_runs_dir(agent_name)
if not runs_dir.exists():
return []
digest_files = sorted(
runs_dir.glob("*/digest.md"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)[:max_runs]
result: list[tuple[str, str]] = []
for f in digest_files:
try:
content = f.read_text(encoding="utf-8").strip()
if content:
result.append((f.parent.name, content))
except OSError:
continue
return result
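# A minimal usage sketch of the digest lifecycle (the bus/llm wiring shown
# here is hypothetical; the real callers live in the runtime):
#
#   await consolidate_worker_run(
#       agent_name="researcher",
#       run_id="run-001",
#       outcome_event=completed_event,  # EXECUTION_COMPLETED, or None mid-run
#       bus=session_bus,
#       llm=llm_provider,               # any object exposing acomplete()
#   )
#   for run_id, digest in read_recent_digests("researcher", max_runs=3):
#       print(f"## {run_id}\n{digest}")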
+10
View File
@@ -89,6 +89,16 @@ def main():
register_testing_commands(subparsers)
# Register skill commands (skill list, skill trust, ...)
from framework.skills.cli import register_skill_commands
register_skill_commands(subparsers)
# Register debugger commands (debugger)
from framework.debugger.cli import register_debugger_commands
register_debugger_commands(subparsers)
args = parser.parse_args()
if hasattr(args, "func"):
+251 -2
View File
@@ -51,16 +51,167 @@ def get_preferred_model() -> str:
"""Return the user's preferred LLM model string (e.g. 'anthropic/claude-sonnet-4-20250514')."""
llm = get_hive_config().get("llm", {})
if llm.get("provider") and llm.get("model"):
return f"{llm['provider']}/{llm['model']}"
provider = str(llm["provider"])
model = str(llm["model"]).strip()
# OpenRouter quickstart stores raw model IDs; tolerate pasted "openrouter/<id>" too.
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
model = model[len("openrouter/") :]
if model:
return f"{provider}/{model}"
return "anthropic/claude-sonnet-4-20250514"
def get_preferred_worker_model() -> str | None:
"""Return the user's preferred worker LLM model, or None if not configured.
Reads from the ``worker_llm`` section of ~/.hive/configuration.json.
Returns None when no worker-specific model is set, so callers can
fall back to the default (queen) model via ``get_preferred_model()``.
"""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm.get("provider") and worker_llm.get("model"):
provider = str(worker_llm["provider"])
model = str(worker_llm["model"]).strip()
if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
model = model[len("openrouter/") :]
if model:
return f"{provider}/{model}"
return None
def get_worker_api_key() -> str | None:
"""Return the API key for the worker LLM, falling back to the default key."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_api_key()
# Worker-specific subscription / env var
if worker_llm.get("use_claude_code_subscription"):
try:
from framework.runner.runner import get_claude_code_token
token = get_claude_code_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_codex_subscription"):
try:
from framework.runner.runner import get_codex_token
token = get_codex_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_kimi_code_subscription"):
try:
from framework.runner.runner import get_kimi_code_token
token = get_kimi_code_token()
if token:
return token
except ImportError:
pass
if worker_llm.get("use_antigravity_subscription"):
try:
from framework.runner.runner import get_antigravity_token
token = get_antigravity_token()
if token:
return token
except ImportError:
pass
api_key_env_var = worker_llm.get("api_key_env_var")
if api_key_env_var:
return os.environ.get(api_key_env_var)
# Fall back to default key
return get_api_key()
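# Illustrative ~/.hive/configuration.json fragment exercising the fallback
# order above (key names come from this module; values are placeholders):
#
#   "worker_llm": {
#       "provider": "openrouter",
#       "model": "deepseek/deepseek-chat",
#       "api_key_env_var": "OPENROUTER_API_KEY",
#       "max_tokens": 4096
#   }
#
# A subscription flag such as "use_codex_subscription": true would instead
# resolve the key via the matching get_*_token() helper.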
def get_worker_api_base() -> str | None:
"""Return the api_base for the worker LLM, falling back to the default."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_api_base()
if worker_llm.get("use_codex_subscription"):
return "https://chatgpt.com/backend-api/codex"
if worker_llm.get("use_kimi_code_subscription"):
return "https://api.kimi.com/coding"
if worker_llm.get("use_antigravity_subscription"):
# Antigravity uses AntigravityProvider directly — no api_base needed.
return None
if worker_llm.get("api_base"):
return worker_llm["api_base"]
if str(worker_llm.get("provider", "")).lower() == "openrouter":
return OPENROUTER_API_BASE
return None
def get_worker_llm_extra_kwargs() -> dict[str, Any]:
"""Return extra kwargs for the worker LLM provider."""
worker_llm = get_hive_config().get("worker_llm", {})
if not worker_llm:
return get_llm_extra_kwargs()
if worker_llm.get("use_claude_code_subscription"):
api_key = get_worker_api_key()
if api_key:
return {
"extra_headers": {"authorization": f"Bearer {api_key}"},
}
if worker_llm.get("use_codex_subscription"):
api_key = get_worker_api_key()
if api_key:
headers: dict[str, str] = {
"Authorization": f"Bearer {api_key}",
"User-Agent": "CodexBar",
}
try:
from framework.runner.runner import get_codex_account_id
account_id = get_codex_account_id()
if account_id:
headers["ChatGPT-Account-Id"] = account_id
except ImportError:
pass
return {
"extra_headers": headers,
"store": False,
"allowed_openai_params": ["store"],
}
return {}
def get_worker_max_tokens() -> int:
"""Return max_tokens for the worker LLM, falling back to default."""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm and "max_tokens" in worker_llm:
return worker_llm["max_tokens"]
return get_max_tokens()
def get_worker_max_context_tokens() -> int:
"""Return max_context_tokens for the worker LLM, falling back to default."""
worker_llm = get_hive_config().get("worker_llm", {})
if worker_llm and "max_context_tokens" in worker_llm:
return worker_llm["max_context_tokens"]
return get_max_context_tokens()
def get_max_tokens() -> int:
"""Return the configured max_tokens, falling back to DEFAULT_MAX_TOKENS."""
return get_hive_config().get("llm", {}).get("max_tokens", DEFAULT_MAX_TOKENS)
DEFAULT_MAX_CONTEXT_TOKENS = 32_000
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
def get_max_context_tokens() -> int:
@@ -113,6 +264,17 @@ def get_api_key() -> str | None:
except ImportError:
pass
# Antigravity subscription: read OAuth token from accounts JSON
if llm.get("use_antigravity_subscription"):
try:
from framework.runner.runner import get_antigravity_token
token = get_antigravity_token()
if token:
return token
except ImportError:
pass
# Standard env-var path (covers ZAI Code and all API-key providers)
api_key_env_var = llm.get("api_key_env_var")
if api_key_env_var:
@@ -120,6 +282,86 @@ def get_api_key() -> str | None:
return None
# OAuth credentials for Antigravity are fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
_ANTIGRAVITY_CREDENTIALS_URL = (
"https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
)
_antigravity_credentials_cache: tuple[str | None, str | None] = (None, None)
def _fetch_antigravity_credentials() -> tuple[str | None, str | None]:
"""Fetch OAuth client ID and secret from the public npm package source on GitHub."""
global _antigravity_credentials_cache
if _antigravity_credentials_cache[0] and _antigravity_credentials_cache[1]:
return _antigravity_credentials_cache
import re
import urllib.request
try:
req = urllib.request.Request(
_ANTIGRAVITY_CREDENTIALS_URL, headers={"User-Agent": "Hive/1.0"}
)
with urllib.request.urlopen(req, timeout=10) as resp:
content = resp.read().decode("utf-8")
id_match = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', content)
secret_match = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', content)
client_id = id_match.group(1) if id_match else None
client_secret = secret_match.group(1) if secret_match else None
if client_id and client_secret:
_antigravity_credentials_cache = (client_id, client_secret)
return client_id, client_secret
except Exception as e:
logger.debug("Failed to fetch Antigravity credentials from public source: %s", e)
return None, None
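# The fetched TypeScript source is expected to contain declarations shaped
# like the following (illustrative placeholders, not the real values):
#
#   export const ANTIGRAVITY_CLIENT_ID = "1234567890-abc.apps.googleusercontent.com";
#   export const ANTIGRAVITY_CLIENT_SECRET = "GOCSPX-...";
#
# The two regexes above capture the quoted values.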
def get_antigravity_client_id() -> str:
"""Return the Antigravity OAuth application client ID.
Checked in order:
1. ``ANTIGRAVITY_CLIENT_ID`` environment variable
2. ``llm.antigravity_client_id`` in ~/.hive/configuration.json
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
"""
env = os.environ.get("ANTIGRAVITY_CLIENT_ID")
if env:
return env
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_id")
if cfg_val:
return cfg_val
# Fetch from public source
client_id, _ = _fetch_antigravity_credentials()
if client_id:
return client_id
raise RuntimeError("Could not obtain Antigravity OAuth client ID")
def get_antigravity_client_secret() -> str | None:
"""Return the Antigravity OAuth client secret.
Checked in order:
1. ``ANTIGRAVITY_CLIENT_SECRET`` environment variable
2. ``llm.antigravity_client_secret`` in ~/.hive/configuration.json
3. Fetch from public source (opencode-antigravity-auth project on GitHub)
Returns None when not found; token refresh will be skipped and
the caller must use whatever access token is already available.
"""
env = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
if env:
return env
cfg_val = get_hive_config().get("llm", {}).get("antigravity_client_secret") or None
if cfg_val:
return cfg_val
# Fetch from public source
_, secret = _fetch_antigravity_credentials()
return secret
def get_gcu_enabled() -> bool:
"""Return whether GCU (browser automation) is enabled in user config."""
return get_hive_config().get("gcu_enabled", True)
@@ -142,7 +384,14 @@ def get_api_base() -> str | None:
if llm.get("use_kimi_code_subscription"):
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
return "https://api.kimi.com/coding"
return llm.get("api_base")
if llm.get("use_antigravity_subscription"):
# Antigravity uses AntigravityProvider directly — no api_base needed.
return None
if llm.get("api_base"):
return llm["api_base"]
if str(llm.get("provider", "")).lower() == "openrouter":
return OPENROUTER_API_BASE
return None
def get_llm_extra_kwargs() -> dict[str, Any]:
+10
View File
@@ -51,6 +51,16 @@ def ensure_credential_key_env() -> None:
if found and value:
os.environ[var_name] = value
logger.debug("Loaded %s from shell config", var_name)
# Also load the currently configured LLM env var even if it's not in CREDENTIAL_SPECS.
# This keeps quickstart-written keys available to fresh processes on Unix shells.
from framework.config import get_hive_config
llm_env_var = str(get_hive_config().get("llm", {}).get("api_key_env_var", "")).strip()
if llm_env_var and not os.environ.get(llm_env_var):
found, value = check_env_var_in_shell_config(llm_env_var)
if found and value:
os.environ[llm_env_var] = value
logger.debug("Loaded configured LLM env var %s from shell config", llm_env_var)
except ImportError:
pass
+76
View File
@@ -0,0 +1,76 @@
"""CLI command for the LLM debug log viewer."""
import argparse
import subprocess
import sys
from pathlib import Path
_SCRIPT = Path(__file__).resolve().parents[3] / "scripts" / "llm_debug_log_visualizer.py"
def register_debugger_commands(subparsers: argparse._SubParsersAction) -> None:
"""Register the ``hive debugger`` command."""
parser = subparsers.add_parser(
"debugger",
help="Open the LLM debug log viewer",
description=(
"Start a local server that lets you browse LLM debug sessions "
"recorded in ~/.hive/llm_logs. Sessions are loaded on demand so "
"the browser stays responsive."
),
)
parser.add_argument(
"--session",
help="Execution ID to select initially.",
)
parser.add_argument(
"--port",
type=int,
default=0,
help="Port for the local server (0 = auto-pick a free port).",
)
parser.add_argument(
"--logs-dir",
help="Directory containing JSONL log files (default: ~/.hive/llm_logs).",
)
parser.add_argument(
"--limit-files",
type=int,
default=None,
help="Maximum number of newest log files to scan (default: 200).",
)
parser.add_argument(
"--output",
help="Write a static HTML file instead of starting a server.",
)
parser.add_argument(
"--no-open",
action="store_true",
help="Start the server but do not open a browser.",
)
parser.add_argument(
"--include-tests",
action="store_true",
help="Show test/mock sessions (hidden by default).",
)
parser.set_defaults(func=cmd_debugger)
def cmd_debugger(args: argparse.Namespace) -> int:
"""Launch the LLM debug log visualizer."""
cmd: list[str] = [sys.executable, str(_SCRIPT)]
if args.session:
cmd += ["--session", args.session]
if args.port:
cmd += ["--port", str(args.port)]
if args.logs_dir:
cmd += ["--logs-dir", args.logs_dir]
if args.limit_files is not None:
cmd += ["--limit-files", str(args.limit_files)]
if args.output:
cmd += ["--output", args.output]
if args.no_open:
cmd.append("--no-open")
if args.include_tests:
cmd.append("--include-tests")
return subprocess.call(cmd)
+36 -2
View File
@@ -33,10 +33,20 @@ class Message:
is_transition_marker: bool = False
# True when this message is real human input (from /chat), not a system prompt
is_client_input: bool = False
# Optional image content blocks (e.g. from browser_screenshot)
image_content: list[dict[str, Any]] | None = None
# True when message contains an activated skill body (AS-10: never prune)
is_skill_content: bool = False
def to_llm_dict(self) -> dict[str, Any]:
"""Convert to OpenAI-format message dict."""
if self.role == "user":
if self.image_content:
blocks: list[dict[str, Any]] = []
if self.content:
blocks.append({"type": "text", "text": self.content})
blocks.extend(self.image_content)
return {"role": "user", "content": blocks}
return {"role": "user", "content": self.content}
if self.role == "assistant":
@@ -47,6 +57,15 @@ class Message:
# role == "tool"
content = f"ERROR: {self.content}" if self.is_error else self.content
if self.image_content:
# Multimodal tool result: text + image content blocks
blocks: list[dict[str, Any]] = [{"type": "text", "text": content}]
blocks.extend(self.image_content)
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
"content": blocks,
}
return {
"role": "tool",
"tool_call_id": self.tool_use_id,
@@ -72,6 +91,8 @@ class Message:
d["is_transition_marker"] = self.is_transition_marker
if self.is_client_input:
d["is_client_input"] = self.is_client_input
if self.image_content is not None:
d["image_content"] = self.image_content
return d
@classmethod
@@ -87,6 +108,7 @@ class Message:
phase_id=data.get("phase_id"),
is_transition_marker=data.get("is_transition_marker", False),
is_client_input=data.get("is_client_input", False),
image_content=data.get("image_content"),
)
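# Shape sketch of a multimodal user message (inferred from to_llm_dict above;
# the data URL is a placeholder):
#
#   msg = Message(seq=0, role="user", content="What does this chart show?",
#                 image_content=[{"type": "image_url",
#                                 "image_url": {"url": "data:image/png;base64,..."}}])
#   msg.to_llm_dict()
#   # → {"role": "user", "content": [
#   #       {"type": "text", "text": "What does this chart show?"},
#   #       {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}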
@@ -373,6 +395,7 @@ class NodeConversation:
*,
is_transition_marker: bool = False,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -381,6 +404,7 @@ class NodeConversation:
phase_id=self._current_phase,
is_transition_marker=is_transition_marker,
is_client_input=is_client_input,
image_content=image_content,
)
self._messages.append(msg)
self._next_seq += 1
@@ -409,6 +433,8 @@ class NodeConversation:
tool_use_id: str,
content: str,
is_error: bool = False,
image_content: list[dict[str, Any]] | None = None,
is_skill_content: bool = False,
) -> Message:
msg = Message(
seq=self._next_seq,
@@ -417,6 +443,8 @@ class NodeConversation:
tool_use_id=tool_use_id,
is_error=is_error,
phase_id=self._current_phase,
image_content=image_content,
is_skill_content=is_skill_content,
)
self._messages.append(msg)
self._next_seq += 1
@@ -610,8 +638,15 @@ class NodeConversation:
continue
if msg.is_error:
continue # never prune errors
if msg.is_skill_content:
continue # never prune activated skill instructions (AS-10)
if msg.content.startswith("[Pruned tool result"):
continue # already pruned
# Tiny results (set_output acks, confirmations) — pruning
# saves negligible space but makes the LLM think the call
# failed, causing costly retries.
if len(msg.content) < 100:
continue
# Phase-aware: protect current phase messages
if self._current_phase and msg.phase_id == self._current_phase:
@@ -901,8 +936,7 @@ class NodeConversation:
full_path = str((spill_path / conv_filename).resolve())
ref_parts.append(
f"[Previous conversation saved to '{full_path}'. "
f"Use load_data('{conv_filename}'), read_file('{full_path}'), "
f"or run_command('cat \"{full_path}\"') to review if needed.]"
f"Use load_data('{conv_filename}') to review if needed.]"
)
elif not collapsed_msgs:
ref_parts.append("[Previous freeform messages compacted.]")
+526 -96
View File
@@ -14,6 +14,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from collections.abc import Awaitable, Callable
@@ -24,6 +25,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from framework.graph.conversation import ConversationStore, NodeConversation
from framework.graph.node import NodeContext, NodeProtocol, NodeResult
from framework.llm.capabilities import supports_image_tool_results
from framework.llm.provider import Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
@@ -37,6 +39,56 @@ from framework.runtime.llm_debug_logger import log_llm_turn
logger = logging.getLogger(__name__)
async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str | None:
"""Describe images using the best available vision model.
Called when the queen's model lacks vision support. Tries vision-capable
models in priority order based on available API keys and returns a bracketed
description to inject into the message text, or None if no vision model is
reachable.
"""
import litellm
# Build content blocks: prompt + all images
blocks: list[dict[str, Any]] = [
{
"type": "text",
"text": (
"Describe the following image(s) concisely but with enough detail "
"that a text-only AI assistant can understand the content and context."
),
}
]
blocks.extend(image_content)
# Ordered candidates based on available env vars
candidates: list[str] = []
if os.environ.get("OPENAI_API_KEY"):
candidates.append("gpt-4o-mini")
if os.environ.get("ANTHROPIC_API_KEY"):
candidates.append("claude-3-haiku-20240307")
if os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"):
candidates.append("gemini/gemini-1.5-flash")
for model in candidates:
try:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": blocks}],
max_tokens=512,
)
description = (response.choices[0].message.content or "").strip()
if description:
count = len(image_content)
label = "image" if count == 1 else f"{count} images"
return f"[{label} attached — description: {description}]"
except Exception as exc:
logger.debug("Vision fallback model '%s' failed: %s", model, exc)
continue
return None
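# Hypothetical call-site sketch, assuming one screenshot block in OpenAI
# image_url format (the data URL is a placeholder):
#
#   desc = await _describe_images_as_text(
#       [{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
#   )
#   # → "[image attached — description: ...]" or None when no vision key is set.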
@dataclass
class TriggerEvent:
"""A framework-level trigger signal (timer tick or webhook hit).
@@ -90,7 +142,13 @@ class _EscalationReceiver:
self._response: str | None = None
self._awaiting_input = True # So inject_worker_message() can prefer us
async def inject_event(
self,
content: str,
*,
is_client_input: bool = False,
image_content: list[dict] | None = None,
) -> None:
"""Called by ExecutionStream.inject_input() when the user responds."""
self._response = content
self._event.set()
@@ -243,7 +301,7 @@ class LoopConfig:
# Maximum seconds a delegate_to_sub_agent call may run before being
# killed. Subagents run a full event-loop so they naturally take
# longer than a single tool call — default is 10 minutes. 0 = no timeout.
subagent_timeout_seconds: float = 600.0
# --- Lifecycle hooks ---
# Hooks are async callables keyed by event name. Supported events:
@@ -293,13 +351,26 @@ class OutputAccumulator:
Values are stored in memory and optionally written through to a
ConversationStore's cursor data for crash recovery.
When *spillover_dir* and *max_value_chars* are set, large values are
automatically saved to files and replaced with lightweight file
references. This guarantees auto-spill fires on **every** ``set()``
call regardless of code path (resume, checkpoint restore, etc.).
"""
values: dict[str, Any] = field(default_factory=dict)
store: ConversationStore | None = None
spillover_dir: str | None = None
max_value_chars: int = 0 # 0 = disabled
async def set(self, key: str, value: Any) -> None:
"""Set a key-value pair, persisting immediately if store is available."""
"""Set a key-value pair, auto-spilling large values to files.
When the serialised value exceeds *max_value_chars*, the data is
saved to ``<spillover_dir>/output_<key>.<ext>`` and *value* is
replaced with a compact file-reference string.
"""
value = self._auto_spill(key, value)
self.values[key] = value
if self.store:
cursor = await self.store.read_cursor() or {}
@@ -308,6 +379,39 @@ class OutputAccumulator:
cursor["outputs"] = outputs
await self.store.write_cursor(cursor)
def _auto_spill(self, key: str, value: Any) -> Any:
"""Save large values to a file and return a reference string."""
if self.max_value_chars <= 0 or not self.spillover_dir:
return value
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
if len(val_str) <= self.max_value_chars:
return value
spill_path = Path(self.spillover_dir)
spill_path.mkdir(parents=True, exist_ok=True)
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
filename = f"output_{key}{ext}"
write_content = (
json.dumps(value, indent=2, ensure_ascii=False)
if isinstance(value, (dict, list))
else str(value)
)
(spill_path / filename).write_text(write_content, encoding="utf-8")
file_size = (spill_path / filename).stat().st_size
logger.info(
"set_output value auto-spilled: key=%s, %d chars → %s (%d bytes)",
key,
len(val_str),
filename,
file_size,
)
return (
f"[Saved to '{filename}' ({file_size:,} bytes). "
f"Use load_data(filename='{filename}') "
f"to access full data.]"
)
def get(self, key: str) -> Any | None:
"""Get a value by key, or None if not present."""
return self.values.get(key)
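# A minimal sketch of the auto-spill path (directory and limit are illustrative;
# the stored reference string follows the format built in _auto_spill):
#
#   acc = OutputAccumulator(spillover_dir="/tmp/spill", max_value_chars=1000)
#   await acc.set("report", "x" * 50_000)
#   acc.get("report")
#   # → "[Saved to 'output_report.txt' (50,000 bytes). Use load_data(...) ...]"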
@@ -380,7 +484,9 @@ class EventLoopNode(NodeProtocol):
self._config = config or LoopConfig()
self._tool_executor = tool_executor
self._conversation_store = conversation_store
self._injection_queue: asyncio.Queue[tuple[str, bool, list[dict[str, Any]] | None]] = (
asyncio.Queue()
)
self._trigger_queue: asyncio.Queue[TriggerEvent] = asyncio.Queue()
# Client-facing input blocking state
self._input_ready = asyncio.Event()
@@ -421,6 +527,8 @@ class EventLoopNode(NodeProtocol):
stream_id = ctx.stream_id or ctx.node_id
node_id = ctx.node_id
execution_id = ctx.execution_id or ""
# Store skill dirs for AS-9 file-read interception in _execute_tool
self._skill_dirs: list[str] = ctx.skill_dirs
# Verdict counters for runtime logging
_accept_count = _retry_count = _escalate_count = _continue_count = 0
@@ -467,7 +575,11 @@ class EventLoopNode(NodeProtocol):
conversation._output_keys = (
ctx.cumulative_output_keys or ctx.node_spec.output_keys or None
)
accumulator = OutputAccumulator(
store=self._conversation_store,
spillover_dir=self._config.spillover_dir,
max_value_chars=self._config.max_output_value_chars,
)
start_iteration = 0
_restored_recent_responses: list[str] = []
_restored_tool_fingerprints: list[list[tuple[str, str]]] = []
@@ -481,12 +593,28 @@ class EventLoopNode(NodeProtocol):
_restored_recent_responses = restored.recent_responses
_restored_tool_fingerprints = restored.recent_tool_fingerprints
# Refresh the system prompt with full composition including
# execution preamble and node-type preamble. The stored
# prompt may be stale after code changes or when runtime-
# injected context (e.g. worker identity) has changed.
from framework.graph.prompt_composer import (
EXECUTION_SCOPE_PREAMBLE,
compose_system_prompt,
)
_exec_preamble = None
if (
not ctx.is_subagent_mode
and ctx.node_spec.node_type in ("event_loop", "gcu")
and ctx.node_spec.output_keys
):
_exec_preamble = EXECUTION_SCOPE_PREAMBLE
_node_type_preamble = None
if ctx.node_spec.node_type == "gcu":
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
_node_type_preamble = GCU_BROWSER_SYSTEM_PROMPT
_current_prompt = compose_system_prompt(
identity_prompt=ctx.identity_prompt or None,
@@ -495,6 +623,8 @@ class EventLoopNode(NodeProtocol):
accounts_prompt=ctx.accounts_prompt or None,
skills_catalog_prompt=ctx.skills_catalog_prompt or None,
protocols_prompt=ctx.protocols_prompt or None,
execution_preamble=_exec_preamble,
node_type_preamble=_node_type_preamble,
)
if conversation.system_prompt != _current_prompt:
conversation.update_system_prompt(_current_prompt)
@@ -504,9 +634,21 @@ class EventLoopNode(NodeProtocol):
_restored_tool_fingerprints = []
# Fresh conversation: either isolated mode or first node in continuous mode.
from framework.graph.prompt_composer import (
EXECUTION_SCOPE_PREAMBLE,
_with_datetime,
)
system_prompt = _with_datetime(ctx.node_spec.system_prompt or "")
# Prepend execution-scope preamble for worker nodes so the
# LLM knows it is one step in a pipeline and should not try
# to perform work that belongs to other nodes.
if (
not ctx.is_subagent_mode
and ctx.node_spec.node_type in ("event_loop", "gcu")
and ctx.node_spec.output_keys
):
system_prompt = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{system_prompt}"
# Prepend GCU browser best-practices prompt for gcu nodes
if ctx.node_spec.node_type == "gcu":
from framework.graph.gcu import GCU_BROWSER_SYSTEM_PROMPT
@@ -573,7 +715,11 @@ class EventLoopNode(NodeProtocol):
# Stamp phase for first node in continuous mode
if _is_continuous:
conversation.set_current_phase(ctx.node_id)
accumulator = OutputAccumulator(
store=self._conversation_store,
spillover_dir=self._config.spillover_dir,
max_value_chars=self._config.max_output_value_chars,
)
start_iteration = 0
# Add initial user message from input data
@@ -698,7 +844,7 @@ class EventLoopNode(NodeProtocol):
)
# 6b. Drain injection queue
await self._drain_injection_queue(conversation, ctx)
# 6b1. Drain trigger queue (framework-level signals)
await self._drain_trigger_queue(conversation)
@@ -740,6 +886,13 @@ class EventLoopNode(NodeProtocol):
execution_id,
extra_data=_iter_meta,
)
# Sync max_context_tokens from live config so mid-session model
# switches are reflected in compaction decisions and the UI bar.
from framework.config import get_max_context_tokens as _live_mct
conversation._max_context_tokens = _live_mct()
await self._publish_context_usage(ctx, conversation, "iteration_start")
# 6d. Pre-turn compaction check (tiered)
_compacted_this_iter = False
@@ -756,6 +909,7 @@ class EventLoopNode(NodeProtocol):
)
_stream_retry_count = 0
_turn_cancelled = False
_llm_turn_failed_waiting_input = False
while True:
try:
(
@@ -875,6 +1029,16 @@ class EventLoopNode(NodeProtocol):
# can retry or adjust the request.
if ctx.node_spec.client_facing:
error_msg = f"LLM call failed: {e}"
_guardrail_phrase = (
"no endpoints available matching your guardrail restrictions "
"and data policy"
)
if _guardrail_phrase in str(e).lower():
error_msg += (
" OpenRouter blocked this model under current privacy settings. "
"Update https://openrouter.ai/settings/privacy or choose another "
"OpenRouter model."
)
logger.error(
"[%s] iter=%d: %s — waiting for user input",
node_id,
@@ -896,6 +1060,7 @@ class EventLoopNode(NodeProtocol):
f"[Error: {error_msg}. Please try again.]"
)
await self._await_user_input(ctx, prompt="")
_llm_turn_failed_waiting_input = True
break # exit retry loop, continue outer iteration
# Non-client-facing: crash as before
@@ -946,6 +1111,11 @@ class EventLoopNode(NodeProtocol):
await self._await_user_input(ctx, prompt="")
continue # back to top of for-iteration loop
# Client-facing non-transient LLM failures wait for user input and then
# continue the outer loop without touching per-turn token vars.
if _llm_turn_failed_waiting_input:
continue
# 6e'. Feed actual API token count back for accurate estimation
turn_input = turn_tokens.get("input", 0)
if turn_input > 0:
@@ -1800,7 +1970,13 @@ class EventLoopNode(NodeProtocol):
conversation=conversation if _is_continuous else None,
)
async def inject_event(
self,
content: str,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> None:
"""Inject an external event or user input into the running loop.
The content becomes a user message prepended to the next iteration.
@@ -1816,8 +1992,10 @@ class EventLoopNode(NodeProtocol):
human user (e.g. /chat endpoint), False for external events
(e.g. worker question forwarded by the frontend). Controls
message formatting in _drain_injection_queue, not wake behavior.
image_content: Optional list of image content blocks (OpenAI
image_url format) to include alongside the text.
"""
await self._injection_queue.put((content, is_client_input, image_content))
self._input_ready.set()
async def inject_trigger(self, trigger: TriggerEvent) -> None:
@@ -1991,6 +2169,24 @@ class EventLoopNode(NodeProtocol):
messages = conversation.to_llm_messages()
# Debug: log whether the last user message contains image blocks
for _m in reversed(messages):
if _m.get("role") == "user":
_content = _m.get("content")
if isinstance(_content, list):
_img_count = sum(
1
for _b in _content
if isinstance(_b, dict) and _b.get("type") == "image_url"
)
if _img_count:
logger.info(
"[%s] LLM call: last user message has %d image block(s)",
node_id,
_img_count,
)
break
# Defensive guard: ensure messages don't end with an assistant
# message. The Anthropic API rejects "assistant message prefill"
# (conversations must end with a user or tool message). This can
@@ -2197,58 +2393,24 @@ class EventLoopNode(NodeProtocol):
pass
key = tc.tool_input.get("key", "")
# Auto-spill happens inside accumulator.set()
# — it fires on every code path (fresh, resume,
# restore) and prevents overwrite regression.
await accumulator.set(key, value)
stored = accumulator.get(key)
# If the accumulator spilled, update the tool
# result so the LLM knows data was saved to a file.
if isinstance(stored, str) and stored.startswith("[Saved to '"):
result = ToolResult(
tool_use_id=tc.tool_use_id,
content=(
f"Output '{key}' auto-saved to file "
f"(value was too large for inline). "
f"{stored}"
),
is_error=False,
)
self._record_learning(key, stored)
outputs_set_this_turn.append(key)
await self._publish_output_key_set(stream_id, node_id, key, execution_id)
logged_tool_calls.append(
@@ -2266,7 +2428,6 @@ class EventLoopNode(NodeProtocol):
elif tc.tool_name == "ask_user":
# --- Framework-level ask_user handling ---
ask_user_prompt = tc.tool_input.get("question", "")
raw_options = tc.tool_input.get("options", None)
# Defensive: ensure options is a list of strings.
@@ -2303,6 +2464,8 @@ class EventLoopNode(NodeProtocol):
user_input_requested = False
continue
user_input_requested = True
# Free-form ask_user (no options): stream the question
# text as a chat message so the user can see it. When
# options are present the QuestionWidget shows the
@@ -2328,7 +2491,6 @@ class EventLoopNode(NodeProtocol):
elif tc.tool_name == "ask_user_multiple":
# --- Framework-level ask_user_multiple ---
raw_questions = tc.tool_input.get("questions", [])
if not isinstance(raw_questions, list) or len(raw_questions) < 2:
result = ToolResult(
@@ -2366,6 +2528,8 @@ class EventLoopNode(NodeProtocol):
}
)
user_input_requested = True
# Store as multi-question prompt/options for
# the event emission path
ask_user_prompt = ""
@@ -2426,6 +2590,27 @@ class EventLoopNode(NodeProtocol):
results_by_id[tc.tool_use_id] = result
elif tc.tool_name == "delegate_to_sub_agent":
# Guard: in continuous mode the LLM may see delegate
# calls from a previous node's conversation history and
# attempt to re-use the tool on a node that doesn't own
# it. Only accept if the tool was actually offered.
if not any(t.name == "delegate_to_sub_agent" for t in tools):
logger.warning(
"[%s] LLM called delegate_to_sub_agent but tool "
"was not offered to this node — rejecting",
node_id,
)
result = ToolResult(
tool_use_id=tc.tool_use_id,
content=(
"ERROR: delegate_to_sub_agent is not available "
"on this node. This tool belongs to a different "
"node in the workflow."
),
is_error=True,
)
results_by_id[tc.tool_use_id] = result
continue
# --- Framework-level subagent delegation ---
# Queue for parallel execution in Phase 2
logger.info(
@@ -2627,6 +2812,11 @@ class EventLoopNode(NodeProtocol):
content=raw.content,
is_error=raw.is_error,
)
# Route through _truncate_tool_result so large
# subagent results are saved to spillover files
# and survive pruning (instead of being "cleared
# from context" with no recovery path).
result = self._truncate_tool_result(result, "delegate_to_sub_agent")
results_by_id[tc.tool_use_id] = result
logged_tool_calls.append(
{
@@ -2666,12 +2856,28 @@ class EventLoopNode(NodeProtocol):
real_tool_results.append(tool_entry)
logged_tool_calls.append(tool_entry)
# Strip image content for models that can't handle it
image_content = result.image_content
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
logger.info(
"Stripping image_content from tool result — model '%s' "
"does not support images in tool results",
ctx.llm.model,
)
image_content = None
await conversation.add_tool_result(
tool_use_id=tc.tool_use_id,
content=result.content,
is_error=result.is_error,
image_content=image_content,
is_skill_content=result.is_skill_content,
)
if tc.tool_name in ("ask_user", "ask_user_multiple"):
if (
tc.tool_name in ("ask_user", "ask_user_multiple")
and user_input_requested
and not result.is_error
):
# Defer tool_call_completed until after user responds
self._deferred_tool_complete = {
"stream_id": stream_id,
@@ -2774,6 +2980,8 @@ class EventLoopNode(NodeProtocol):
conversation.usage_ratio() * 100,
)
await self._publish_context_usage(ctx, conversation, "post_tool_results")
# If the turn requested external input (ask_user or queen handoff),
# return immediately so the outer loop can block before judge eval.
if user_input_requested or queen_input_requested:
@@ -3489,6 +3697,33 @@ class EventLoopNode(NodeProtocol):
content=f"No tool executor configured for '{tc.tool_name}'",
is_error=True,
)
# AS-9: Intercept file-read tools for skill directories — bypass session sandbox
_SKILL_READ_TOOLS = {"view_file", "load_data", "read_file"}
skill_dirs = getattr(self, "_skill_dirs", [])
if tc.tool_name in _SKILL_READ_TOOLS and skill_dirs:
_path = tc.tool_input.get("path", "")
if _path:
import os
from pathlib import Path as _Path
_resolved = os.path.realpath(os.path.abspath(_path))
if any(_resolved.startswith(os.path.realpath(d)) for d in skill_dirs):
try:
_content = _Path(_resolved).read_text(encoding="utf-8")
_is_skill_md = _resolved.endswith("SKILL.md")
return ToolResult(
tool_use_id=tc.tool_use_id,
content=_content,
is_skill_content=_is_skill_md, # AS-10: protect SKILL.md reads
)
except Exception as _exc:
return ToolResult(
tool_use_id=tc.tool_use_id,
content=f"Could not read skill resource '{_path}': {_exc}",
is_error=True,
)
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
timeout = self._config.tool_call_timeout_seconds
@@ -3776,6 +4011,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
)
spill_dir = self._config.spillover_dir
@@ -3848,6 +4084,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=content,
is_error=False,
image_content=result.image_content,
)
# No spillover_dir — truncate in-place if needed
@@ -3890,6 +4127,7 @@ class EventLoopNode(NodeProtocol):
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
)
return result
@@ -3920,6 +4158,12 @@ class EventLoopNode(NodeProtocol):
ratio_before = conversation.usage_ratio()
phase_grad = getattr(ctx, "continuous_mode", False)
# Capture pre-compaction message inventory when over budget,
# since compaction mutates the conversation in place.
pre_inventory: list[dict[str, Any]] | None = None
if ratio_before >= 1.0:
pre_inventory = self._build_message_inventory(conversation)
# --- Step 1: Prune old tool results (free, no LLM) ---
protect = max(2000, self._config.max_context_tokens // 12)
pruned = await conversation.prune_old_tool_results(
@@ -3934,7 +4178,7 @@ class EventLoopNode(NodeProtocol):
conversation.usage_ratio() * 100,
)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
@@ -3947,7 +4191,7 @@ class EventLoopNode(NodeProtocol):
phase_graduated=phase_grad,
)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 3: LLM summary compaction ---
@@ -3974,7 +4218,7 @@ class EventLoopNode(NodeProtocol):
logger.warning("LLM compaction failed: %s", e)
if not conversation.needs_compaction():
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
return
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
@@ -3988,7 +4232,7 @@ class EventLoopNode(NodeProtocol):
keep_recent=1,
phase_graduated=phase_grad,
)
await self._log_compaction(ctx, conversation, ratio_before, pre_inventory)
# --- LLM compaction with binary-search splitting ----------------------
@@ -4150,13 +4394,59 @@ class EventLoopNode(NodeProtocol):
"re-doing work.\n"
)
@staticmethod
def _build_message_inventory(
conversation: NodeConversation,
) -> list[dict[str, Any]]:
"""Build a per-message size inventory for debug logging."""
inventory: list[dict[str, Any]] = []
for m in conversation.messages:
content_chars = len(m.content)
tc_chars = 0
tool_name = None
if m.tool_calls:
for tc in m.tool_calls:
args = tc.get("function", {}).get("arguments", "")
tc_chars += len(args) if isinstance(args, str) else len(json.dumps(args))
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
tool_name = ", ".join(names)
elif m.role == "tool" and m.tool_use_id:
for prev in conversation.messages:
if prev.tool_calls:
for tc in prev.tool_calls:
if tc.get("id") == m.tool_use_id:
tool_name = tc.get("function", {}).get("name", "?")
break
if tool_name:
break
entry: dict[str, Any] = {
"seq": m.seq,
"role": m.role,
"content_chars": content_chars,
}
if tc_chars:
entry["tool_call_args_chars"] = tc_chars
if tool_name:
entry["tool"] = tool_name
if m.is_error:
entry["is_error"] = True
if m.phase_id:
entry["phase"] = m.phase_id
if content_chars > 2000:
entry["preview"] = m.content[:200] + ""
inventory.append(entry)
return inventory
async def _log_compaction(
self,
ctx: NodeContext,
conversation: NodeConversation,
ratio_before: float,
pre_inventory: list[dict[str, Any]] | None = None,
) -> None:
"""Log compaction result to runtime logger and event bus."""
"""Log compaction result to runtime logger, event bus, and debug file."""
import os as _os
ratio_after = conversation.usage_ratio()
before_pct = round(ratio_before * 100)
after_pct = round(ratio_after * 100)
@@ -4189,19 +4479,103 @@ class EventLoopNode(NodeProtocol):
if self._event_bus:
from framework.runtime.event_bus import AgentEvent, EventType
event_data: dict[str, Any] = {
"level": level,
"usage_before": before_pct,
"usage_after": after_pct,
}
if pre_inventory is not None:
event_data["message_inventory"] = pre_inventory
await self._event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_COMPACTED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data=event_data,
)
)
# Emit post-compaction usage update
await self._publish_context_usage(ctx, conversation, "post_compaction")
# Write detailed debug log to ~/.hive/compaction_log/ when enabled
if _os.environ.get("HIVE_COMPACTION_DEBUG"):
self._write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
@staticmethod
def _write_compaction_debug_log(
ctx: NodeContext,
before_pct: int,
after_pct: int,
level: str,
inventory: list[dict[str, Any]] | None,
) -> None:
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
log_dir = Path.home() / ".hive" / "compaction_log"
log_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
node_label = ctx.node_id.replace("/", "_")
log_path = log_dir / f"{ts}_{node_label}.md"
lines: list[str] = [
f"# Compaction Debug — {ctx.node_id}",
f"**Time:** {datetime.now(UTC).isoformat()}",
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
]
if ctx.stream_id:
lines.append(f"**Stream:** {ctx.stream_id}")
lines.append(f"**Level:** {level}")
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
lines.append("")
if inventory:
total_chars = sum(
e.get("content_chars", 0) + e.get("tool_call_args_chars", 0) for e in inventory
)
lines.append(
f"## Pre-Compaction Message Inventory "
f"({len(inventory)} messages, {total_chars:,} total chars)"
)
lines.append("")
ranked = sorted(
inventory,
key=lambda e: e.get("content_chars", 0) + e.get("tool_call_args_chars", 0),
reverse=True,
)
lines.append("| # | seq | role | tool | chars | % of total | flags |")
lines.append("|---|-----|------|------|------:|------------|-------|")
for i, entry in enumerate(ranked, 1):
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
pct = (chars / total_chars * 100) if total_chars else 0
tool = entry.get("tool", "")
flags = []
if entry.get("is_error"):
flags.append("error")
if entry.get("phase"):
flags.append(f"phase={entry['phase']}")
lines.append(
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
)
large = [e for e in ranked if e.get("preview")]
if large:
lines.append("")
lines.append("### Large message previews")
for entry in large:
lines.append(
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
)
lines.append(f"```\n{entry['preview']}\n```")
lines.append("")
try:
log_path.write_text("\n".join(lines), encoding="utf-8")
logger.debug("Compaction debug log written to %s", log_path)
except OSError:
logger.debug("Failed to write compaction debug log to %s", log_path)
def _build_emergency_summary(
self,
ctx: NodeContext,
@@ -4287,17 +4661,14 @@ class EventLoopNode(NodeProtocol):
)
parts.append(
"CONVERSATION HISTORY (freeform messages saved during compaction — "
"use load_data('<filename>'), read_file('<full_path>'), "
"or run_command('cat \"<full_path>\"') to review earlier dialogue):\n"
+ conv_list
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
)
if data_files:
file_list = "\n".join(
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
)
parts.append(
"DATA FILES (use load_data('<filename>'), read_file('<full_path>'), "
"or run_command('cat \"<full_path>\"') to read):\n" + file_list
"DATA FILES (use load_data('<filename>') to read):\n" + file_list
)
if not all_files:
parts.append(
@@ -4363,6 +4734,8 @@ class EventLoopNode(NodeProtocol):
return None
accumulator = await OutputAccumulator.restore(self._conversation_store)
accumulator.spillover_dir = self._config.spillover_dir
accumulator.max_value_chars = self._config.max_output_value_chars
cursor = await self._conversation_store.read_cursor()
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
@@ -4425,20 +4798,37 @@ class EventLoopNode(NodeProtocol):
]
await self._conversation_store.write_cursor(cursor)
async def _drain_injection_queue(self, conversation: NodeConversation, ctx: NodeContext) -> int:
"""Drain all pending injected events as user messages. Returns count."""
count = 0
while not self._injection_queue.empty():
try:
content, is_client_input, image_content = self._injection_queue.get_nowait()
logger.info(
"[drain] injected message (client_input=%s): %s",
"[drain] injected message (client_input=%s, images=%d): %s",
is_client_input,
len(image_content) if image_content else 0,
content[:200] if content else "(empty)",
)
# For models that don't support images, fall back to a text description
if image_content and ctx.llm:
if not supports_image_tool_results(ctx.llm.model):
logger.info(
"Model '%s' does not support images — attempting vision fallback",
ctx.llm.model,
)
description = await _describe_images_as_text(image_content)
if description:
content = f"{content}\n\n{description}" if content else description
logger.info("[drain] image described as text via vision fallback")
else:
logger.info("[drain] no vision fallback available — images dropped")
image_content = None
# Real user input is stored as-is; external events get a prefix
if is_client_input:
await conversation.add_user_message(
content, is_client_input=True, image_content=image_content
)
else:
await conversation.add_user_message(f"[External event]: {content}")
count += 1
@@ -4607,6 +4997,36 @@ class EventLoopNode(NodeProtocol):
if result.inject:
await conversation.add_user_message(result.inject)
async def _publish_context_usage(
self,
ctx: NodeContext,
conversation: NodeConversation,
trigger: str,
) -> None:
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
if not self._event_bus:
return
from framework.runtime.event_bus import AgentEvent, EventType
estimated = conversation.estimate_tokens()
max_tokens = conversation._max_context_tokens
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
await self._event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_USAGE_UPDATED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data={
"usage_ratio": round(ratio, 4),
"usage_pct": round(ratio * 100),
"message_count": conversation.message_count,
"estimated_tokens": estimated,
"max_context_tokens": max_tokens,
"trigger": trigger,
},
)
)
async def _publish_iteration(
self,
stream_id: str,
@@ -4891,7 +5311,20 @@ class EventLoopNode(NodeProtocol):
write_keys=[], # Read-only!
)
# 2b. Compute instance counter early so node_id is available for the
# report callback and the NodeContext. Each delegation to the same
# agent_id gets a unique suffix (instance 1 has no suffix for backward
# compat; instance 2+ appends ":N").
self._subagent_instance_counter.setdefault(agent_id, 0)
self._subagent_instance_counter[agent_id] += 1
_sa_instance = self._subagent_instance_counter[agent_id]
if _sa_instance > 1:
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{_sa_instance}"
else:
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
subagent_instance = str(_sa_instance)
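# e.g. a first delegation to "researcher" yields "<node>:subagent:researcher";
# a second delegation in the same run yields "<node>:subagent:researcher:2".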
# 2c. Set up report callback (one-way channel to parent / event bus)
subagent_reports: list[dict] = []
async def _report_callback(
@@ -4904,7 +5337,7 @@ class EventLoopNode(NodeProtocol):
if self._event_bus:
await self._event_bus.emit_subagent_report(
stream_id=ctx.node_id,
node_id=f"{ctx.node_id}:subagent:{agent_id}",
node_id=sa_node_id,
subagent_id=agent_id,
message=message,
data=data,
@@ -4994,7 +5427,7 @@ class EventLoopNode(NodeProtocol):
max_iter = min(self._config.max_iterations, 10)
subagent_ctx = NodeContext(
runtime=ctx.runtime,
node_id=f"{ctx.node_id}:subagent:{agent_id}",
node_id=sa_node_id,
node_spec=subagent_spec,
memory=scoped_memory,
input_data={"task": task, **parent_data},
@@ -5022,10 +5455,7 @@ class EventLoopNode(NodeProtocol):
# Derive a conversation store for the subagent from the parent's store.
# Each invocation gets a unique path so that repeated delegate calls
# (e.g. one per profile) don't restore a stale completed conversation.
# (Instance counter was computed earlier in step 2b.)
subagent_conv_store = None
if self._conversation_store is not None:
from framework.storage.conversation_store import FileConversationStore
+14 -1
View File
@@ -154,6 +154,7 @@ class GraphExecutor:
iteration_metadata_provider: Callable | None = None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize the executor.
@@ -181,6 +182,7 @@ class GraphExecutor:
system prompt (for phase switching)
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
"""
self.runtime = runtime
self.llm = llm
@@ -204,6 +206,7 @@ class GraphExecutor:
self.iteration_metadata_provider = iteration_metadata_provider
self.skills_catalog_prompt = skills_catalog_prompt
self.protocols_prompt = protocols_prompt
self.skill_dirs: list[str] = skill_dirs or []
if protocols_prompt:
self.logger.info(
@@ -1420,6 +1423,7 @@ class GraphExecutor:
next_spec = graph.get_node(current_node_id)
if next_spec and next_spec.node_type == "event_loop":
from framework.graph.prompt_composer import (
EXECUTION_SCOPE_PREAMBLE,
build_accounts_prompt,
build_narrative,
build_transition_marker,
@@ -1459,9 +1463,14 @@ class GraphExecutor:
)
# Compose new system prompt (Layer 1 + 2 + 3 + accounts)
# Prepend scope preamble to focus so the LLM stays
# within this node's responsibility.
_focus = next_spec.system_prompt
if next_spec.output_keys and _focus:
_focus = f"{EXECUTION_SCOPE_PREAMBLE}\n\n{_focus}"
new_system = compose_system_prompt(
identity_prompt=getattr(graph, "identity_prompt", None),
focus_prompt=_focus,
narrative=narrative,
accounts_prompt=_node_accounts,
)
@@ -1839,6 +1848,9 @@ class GraphExecutor:
existing_underscore = [k for k in memory._data if k.startswith("_")]
extra_keys = set(_skill_keys) | set(existing_underscore)
# Only inject into read_keys when it was already non-empty — an empty
# read_keys means "allow all reads" and injecting skill keys would
# inadvertently restrict reads to skill keys only.
for k in extra_keys:
if read_keys and k not in read_keys:
read_keys.append(k)
@@ -1893,6 +1905,7 @@ class GraphExecutor:
iteration_metadata_provider=self.iteration_metadata_provider,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
VALID_NODE_TYPES = {
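The read_keys guard a few lines up is subtle enough to deserve a standalone illustration: an empty list means "allow all reads", so appending skill keys to it would narrow, not widen, the permission set. A minimal, assumption-free sketch:

def inject_extra_keys(read_keys: list[str], extra_keys: set[str]) -> list[str]:
    # Mirrors the guard above: only widen an already-restricted list.
    for k in extra_keys:
        if read_keys and k not in read_keys:
            read_keys.append(k)
    return read_keys

assert inject_extra_keys([], {"_skill_state"}) == []  # allow-all stays allow-all
assert inject_extra_keys(["report"], {"_skill_state"}) == ["report", "_skill_state"]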
+5 -2
@@ -43,8 +43,11 @@ Follow these rules for reliable, efficient browser interaction.
`browser_snapshot` separately after every action.
Only call `browser_snapshot` when you need a fresh view without
performing an action, or after setting `auto_snapshot=false`.
- Do NOT use `browser_screenshot` for reading text content:
it produces huge base64 images with no searchable text.
- Do NOT use `browser_screenshot` to read text; use
`browser_snapshot` for that (compact, searchable, fast).
- DO use `browser_screenshot` when you need visual context:
charts, images, canvas elements, layout verification, or when
the snapshot doesn't capture what you need.
- Only fall back to `browser_get_text` for extracting specific
small elements by CSS selector.
-8
@@ -167,14 +167,6 @@ class Goal(BaseModel):
return met_weight >= total_weight * 0.9 # 90% threshold
def check_constraint(self, constraint_id: str, value: Any) -> bool:
"""Check if a specific constraint is satisfied."""
for c in self.constraints:
if c.id == constraint_id:
# This would be expanded with actual evaluation logic
return True
return True
def to_prompt_context(self) -> str:
"""Generate context string for LLM prompts.
+1
@@ -568,6 +568,7 @@ class NodeContext:
# Skill system prompts — injected by the skill discovery pipeline
skills_catalog_prompt: str = "" # Available skills XML catalog
protocols_prompt: str = "" # Default skill operational protocols
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
# Per-iteration metadata provider — when set, EventLoopNode merges
# the returned dict into node_loop_iteration event data. Used by
+58 -3
@@ -26,6 +26,16 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Injected into every worker node's system prompt so the LLM understands
# it is one step in a multi-node pipeline and should not overreach.
EXECUTION_SCOPE_PREAMBLE = (
"EXECUTION SCOPE: You are one node in a multi-step workflow graph. "
"Focus ONLY on the task described in your instructions below. "
"Call set_output() for each of your declared output keys, then stop. "
"Do NOT attempt work that belongs to other nodes — the framework "
"routes data between nodes automatically."
)
def _with_datetime(prompt: str) -> str:
"""Append current datetime with local timezone to a system prompt."""
@@ -142,6 +152,8 @@ def compose_system_prompt(
accounts_prompt: str | None = None,
skills_catalog_prompt: str | None = None,
protocols_prompt: str | None = None,
execution_preamble: str | None = None,
node_type_preamble: str | None = None,
) -> str:
"""Compose the multi-layer system prompt.
@@ -152,6 +164,10 @@ def compose_system_prompt(
accounts_prompt: Connected accounts block (sits between identity and narrative).
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
protocols_prompt: Default skill operational protocols section.
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
(prepended before focus so the LLM knows its pipeline scope).
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
best-practices prompt (prepended before focus).
Returns:
Composed system prompt with all layers present, plus current datetime.
@@ -178,6 +194,15 @@ def compose_system_prompt(
if narrative:
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
# Execution scope preamble (worker nodes — tells the LLM it is one
# step in a multi-node pipeline and should not overreach)
if execution_preamble:
parts.append(f"\n{execution_preamble}")
# Node-type preamble (e.g. GCU browser best-practices)
if node_type_preamble:
parts.append(f"\n{node_type_preamble}")
# Layer 3: Focus (current phase directive)
if focus_prompt:
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
@@ -267,7 +292,9 @@ def build_transition_marker(
sections.append(f"\nCompleted: {previous_node.name}")
sections.append(f" {previous_node.description}")
# Outputs in memory
# Outputs in memory — use file references for large values so the
# next node loads full data from disk instead of seeing truncated
# inline previews that look deceptively complete.
all_memory = memory.read_all()
if all_memory:
memory_lines: list[str] = []
@@ -275,7 +302,29 @@ def build_transition_marker(
if value is None:
continue
val_str = str(value)
if len(val_str) > 300:
if len(val_str) > 300 and data_dir:
# Auto-spill large transition values to data files
import json as _json
data_path = Path(data_dir)
data_path.mkdir(parents=True, exist_ok=True)
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
filename = f"output_{key}{ext}"
try:
write_content = (
_json.dumps(value, indent=2, ensure_ascii=False)
if isinstance(value, (dict, list))
else str(value)
)
(data_path / filename).write_text(write_content, encoding="utf-8")
file_size = (data_path / filename).stat().st_size
val_str = (
f"[Saved to '{filename}' ({file_size:,} bytes). "
f"Use load_data(filename='{filename}') to access.]"
)
except Exception:
val_str = val_str[:300] + "..."
elif len(val_str) > 300:
val_str = val_str[:300] + "..."
memory_lines.append(f" {key}: {val_str}")
if memory_lines:
@@ -292,7 +341,7 @@ def build_transition_marker(
]
if file_lines:
sections.append(
"\nData files (use read_file to access):\n" + "\n".join(file_lines)
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
)
# Agent working memory
@@ -306,6 +355,12 @@ def build_transition_marker(
# Next phase
sections.append(f"\nNow entering: {next_node.name}")
sections.append(f" {next_node.description}")
if next_node.output_keys:
sections.append(
f"\nYour ONLY job in this phase: complete the task above and call "
f"set_output() for {next_node.output_keys}. Do NOT do work that "
f"belongs to later phases."
)
# Reflection prompt (engineered metacognition)
sections.append(
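Condensed sketch of the spill rule introduced in build_transition_marker above (the 300-char truncation fallback and try/except are omitted): values whose string form exceeds the limit are written to data_dir and replaced inline by a load_data() reference.

import json
from pathlib import Path

def spill_if_large(key: str, value, data_dir: Path, limit: int = 300) -> str:
    # Assumes data_dir exists; the real code calls mkdir(parents=True, exist_ok=True).
    val_str = str(value)
    if len(val_str) <= limit:
        return val_str
    ext = ".json" if isinstance(value, (dict, list)) else ".txt"
    path = data_dir / f"output_{key}{ext}"
    path.write_text(
        json.dumps(value, indent=2, ensure_ascii=False)
        if isinstance(value, (dict, list))
        else val_str,
        encoding="utf-8",
    )
    return (
        f"[Saved to '{path.name}' ({path.stat().st_size:,} bytes). "
        f"Use load_data(filename='{path.name}') to access.]"
    )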
+15 -3
@@ -115,11 +115,23 @@ class SafeEvalVisitor(ast.NodeVisitor):
return True
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
values = [self.visit(v) for v in node.values]
# Short-circuit evaluation to match Python semantics.
# Previously all operands were eagerly evaluated, which broke
# guard patterns like: ``x is not None and x.get("key")``
if isinstance(node.op, ast.And):
return all(values)
result = True
for v in node.values:
result = self.visit(v)
if not result:
return result
return result
elif isinstance(node.op, ast.Or):
return any(values)
result = False
for v in node.values:
result = self.visit(v)
if result:
return result
return result
raise ValueError(f"Boolean operator {type(node.op).__name__} is not allowed")
def visit_IfExp(self, node: ast.IfExp) -> Any:
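The practical effect of the short-circuit fix, in plain Python terms (no visitor machinery needed to see it):

x: dict | None = None
# Old behavior, conceptually: all()/any() evaluated every operand first,
#   values = [x is not None, x.get("key")]   -> AttributeError on None
# New behavior matches Python itself; the right side never runs:
result = x is not None and x.get("key")      # False, x.get is never called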
+706
@@ -0,0 +1,706 @@
"""Antigravity (Google internal Cloud Code Assist) LLM provider.
Antigravity is Google's unified gateway API that routes requests to Gemini,
Claude, and GPT-OSS models through a single Gemini-style interface. It is
NOT the public ``generativelanguage.googleapis.com`` API.
Authentication uses Google OAuth2. Token refresh is done directly with the
OAuth client secret; no local proxy required.
Credential sources (checked in order):
1. ``~/.hive/antigravity-accounts.json`` (native OAuth implementation)
2. Antigravity IDE SQLite state DB (macOS / Linux)
"""
from __future__ import annotations
import json
import logging
import re
import time
import uuid
from collections.abc import AsyncIterator, Callable, Iterator
from pathlib import Path
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_TOKEN_URL = "https://oauth2.googleapis.com/token"
# Fallback order: daily sandbox → autopush sandbox → production
_ENDPOINTS = [
"https://daily-cloudcode-pa.sandbox.googleapis.com",
"https://autopush-cloudcode-pa.sandbox.googleapis.com",
"https://cloudcode-pa.googleapis.com",
]
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
_TOKEN_REFRESH_BUFFER_SECS = 60
# Credentials file in ~/.hive/ (native implementation)
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
_IDE_STATE_DB_MAC = (
Path.home()
/ "Library"
/ "Application Support"
/ "Antigravity"
/ "User"
/ "globalStorage"
/ "state.vscdb"
)
_IDE_STATE_DB_LINUX = (
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
_BASE_HEADERS: dict[str, str] = {
# Mimic the Antigravity Electron app so the API accepts the request.
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
"(KHTML, like Gecko) Antigravity/1.18.3 Chrome/138.0.7204.235 "
"Electron/37.3.1 Safari/537.36"
),
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
"Client-Metadata": '{"ideType":"ANTIGRAVITY","platform":"MACOS","pluginType":"GEMINI"}',
}
# ---------------------------------------------------------------------------
# Credential loading helpers
# ---------------------------------------------------------------------------
def _load_from_json_file() -> tuple[str | None, str | None, str, float]:
"""Read credentials from JSON accounts file.
Reads from ~/.hive/antigravity-accounts.json.
Returns ``(access_token | None, refresh_token | None, project_id, expires_at)``.
``expires_at`` is a Unix timestamp (seconds); 0.0 means unknown.
"""
if not _ACCOUNTS_FILE.exists():
return None, None, _DEFAULT_PROJECT_ID, 0.0
try:
with open(_ACCOUNTS_FILE, encoding="utf-8") as fh:
data = json.load(fh)
except (OSError, json.JSONDecodeError) as exc:
logger.debug("Failed to read Antigravity accounts file: %s", exc)
return None, None, _DEFAULT_PROJECT_ID, 0.0
accounts = data.get("accounts", [])
if not accounts:
return None, None, _DEFAULT_PROJECT_ID, 0.0
account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
schema_version = data.get("schemaVersion", 1)
if schema_version >= 4:
# V4 schema: refresh = "refreshToken|projectId[|managedProjectId]"
refresh_str = account.get("refresh", "")
parts = refresh_str.split("|") if refresh_str else []
refresh_token: str | None = parts[0] if parts else None
project_id = parts[1] if len(parts) >= 2 and parts[1] else _DEFAULT_PROJECT_ID
access_token: str | None = account.get("access")
expires_ms: int = account.get("expires", 0)
expires_at = float(expires_ms) / 1000.0 if expires_ms else 0.0
# Treat near-expiry tokens as absent so _ensure_token() triggers a refresh.
if access_token and expires_at and time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
access_token = None
expires_at = 0.0
return access_token, refresh_token, project_id, expires_at
else:
# V1-V3 schema: plain accessToken / refreshToken fields
access_token = account.get("accessToken")
refresh_token = account.get("refreshToken")
# Estimate expiry from last_refresh + 1 h
last_refresh_str: str | None = data.get("last_refresh")
expires_at = 0.0
if last_refresh_str:
try:
from datetime import datetime # noqa: PLC0415
ts = datetime.fromisoformat(last_refresh_str.replace("Z", "+00:00")).timestamp()
expires_at = ts + 3600.0
if time.time() >= expires_at - _TOKEN_REFRESH_BUFFER_SECS:
access_token = None
except (ValueError, TypeError):
pass
return access_token, refresh_token, _DEFAULT_PROJECT_ID, expires_at
def _load_from_ide_db() -> tuple[str | None, str | None, float]:
"""Extract ``(access_token, refresh_token, expires_at)`` from the IDE SQLite DB."""
import base64 # noqa: PLC0415
import sqlite3 # noqa: PLC0415
for db_path in (_IDE_STATE_DB_MAC, _IDE_STATE_DB_LINUX):
if not db_path.exists():
continue
try:
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
try:
row = con.execute(
"SELECT value FROM ItemTable WHERE key = ?",
(_IDE_STATE_DB_KEY,),
).fetchone()
finally:
con.close()
if not row:
continue
blob = base64.b64decode(row[0])
candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
access_token: str | None = None
refresh_token: str | None = None
for candidate in candidates:
try:
padded = candidate + b"=" * (-len(candidate) % 4)
inner = base64.urlsafe_b64decode(padded)
except Exception:
continue
if not access_token:
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
if m:
access_token = m.group(0).decode("ascii")
if not refresh_token:
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
if m:
refresh_token = m.group(0).decode("ascii")
if access_token and refresh_token:
break
if access_token:
# Estimate expiry from DB mtime (IDE refreshes while running)
mtime = db_path.stat().st_mtime
expires_at = mtime + 3600.0
return access_token, refresh_token, expires_at
except Exception as exc:
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
continue
return None, None, 0.0
def _do_token_refresh(refresh_token: str) -> tuple[str, float] | None:
"""POST to Google OAuth endpoint and return ``(new_access_token, expires_at)``.
The client secret is sourced via ``get_antigravity_client_secret()`` (env var,
config file, or npm package fallback). When unavailable the refresh is attempted
without it Google will reject it for web-app clients, but the npm fallback in
``get_antigravity_client_secret()`` should ensure the secret is found at runtime.
Returns None when the HTTP request fails.
"""
from framework.config import get_antigravity_client_secret # noqa: PLC0415
client_secret = get_antigravity_client_secret()
if not client_secret:
logger.debug(
"Antigravity client secret not configured — attempting refresh without it. "
"Set ANTIGRAVITY_CLIENT_SECRET or run quickstart to configure."
)
import urllib.error # noqa: PLC0415
import urllib.parse # noqa: PLC0415
import urllib.request # noqa: PLC0415
from framework.config import get_antigravity_client_id # noqa: PLC0415
params: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": get_antigravity_client_id(),
}
if client_secret:
params["client_secret"] = client_secret
body = urllib.parse.urlencode(params).encode("utf-8")
req = urllib.request.Request(
_TOKEN_URL,
data=body,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
payload = json.loads(resp.read())
access_token: str = payload["access_token"]
expires_in: int = payload.get("expires_in", 3600)
logger.debug("Antigravity token refreshed successfully")
return access_token, time.time() + expires_in
except Exception as exc:
logger.debug("Antigravity token refresh failed: %s", exc)
return None
# ---------------------------------------------------------------------------
# Message conversion helpers
# ---------------------------------------------------------------------------
def _clean_tool_name(name: str) -> str:
"""Sanitize a tool name for the Antigravity function-calling schema."""
name = re.sub(r"[/\s]", "_", name)
if name and not (name[0].isalpha() or name[0] == "_"):
name = "_" + name
return name[:64]
def _to_gemini_contents(
messages: list[dict[str, Any]],
thought_sigs: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
"""Convert OpenAI-format messages to Gemini-style ``contents`` array."""
# Pre-build a map tool_call_id → function_name from assistant messages.
# Tool result messages (role="tool") only carry tool_call_id, not the name,
# but Gemini requires functionResponse.name to match the functionCall.name.
tc_id_to_name: dict[str, str] = {}
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
tc_id = tc.get("id")
fn_name = tc.get("function", {}).get("name", "")
if tc_id and fn_name:
tc_id_to_name[tc_id] = fn_name
contents: list[dict[str, Any]] = []
# Consecutive tool-result messages must be batched into one user turn.
pending_tool_parts: list[dict[str, Any]] = []
def _flush_tool_results() -> None:
if pending_tool_parts:
contents.append({"role": "user", "parts": list(pending_tool_parts)})
pending_tool_parts.clear()
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content")
if role == "system":
continue # Handled via systemInstruction, not in contents.
if role == "tool":
# OpenAI tool result → Gemini functionResponse part.
result_str = content if isinstance(content, str) else str(content or "")
tc_id = msg.get("tool_call_id", "")
# Look up function name from the pre-built map; fall back to msg.name.
fn_name = tc_id_to_name.get(tc_id) or msg.get("name", "")
pending_tool_parts.append(
{
"functionResponse": {
"name": fn_name,
"id": tc_id,
"response": {"content": result_str},
}
}
)
continue
_flush_tool_results()
gemini_role = "model" if role == "assistant" else "user"
parts: list[dict[str, Any]] = []
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, list):
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
text = block.get("text", "")
if text:
parts.append({"text": text})
# Other block types (image_url etc.) skipped.
# Assistant messages may carry OpenAI-style tool_calls.
for tc in msg.get("tool_calls") or []:
fn = tc.get("function", {})
try:
args = json.loads(fn.get("arguments", "{}") or "{}")
except (json.JSONDecodeError, TypeError):
args = {}
tc_id = tc.get("id", str(uuid.uuid4()))
fc_part: dict[str, Any] = {
"functionCall": {
"name": fn.get("name", ""),
"args": args,
"id": tc_id,
}
}
if thought_sigs:
sig = thought_sigs.get(tc_id, "")
if sig:
fc_part["thoughtSignature"] = sig # part-level, not inside functionCall
parts.append(fc_part)
if parts:
contents.append({"role": gemini_role, "parts": parts})
_flush_tool_results()
# Gemini requires the first turn to be a user turn. Drop any leading
# model messages so the API doesn't reject with a 400.
while contents and contents[0].get("role") == "model":
contents.pop(0)
return contents
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------
def _map_finish_reason(reason: str) -> str:
return {"STOP": "stop", "MAX_TOKENS": "max_tokens", "OTHER": "tool_use"}.get(
(reason or "").upper(), "stop"
)
def _parse_complete_response(raw: dict[str, Any], model: str) -> LLMResponse:
"""Parse a non-streaming Antigravity response dict → LLMResponse."""
payload: dict[str, Any] = raw.get("response", raw)
candidates: list[dict[str, Any]] = payload.get("candidates", [])
usage: dict[str, Any] = payload.get("usageMetadata", {})
text_parts: list[str] = []
if candidates:
for part in candidates[0].get("content", {}).get("parts", []):
if "text" in part and not part.get("thought"):
text_parts.append(part["text"])
return LLMResponse(
content="".join(text_parts),
model=payload.get("modelVersion", model),
input_tokens=usage.get("promptTokenCount", 0),
output_tokens=usage.get("candidatesTokenCount", 0),
stop_reason=_map_finish_reason(candidates[0].get("finishReason", "") if candidates else ""),
raw_response=raw,
)
def _parse_sse_stream(
response: Any,
model: str,
on_thought_signature: Callable[[str, str], None] | None = None,
) -> Iterator[StreamEvent]:
"""Parse Antigravity SSE response line-by-line → StreamEvents.
Each SSE line looks like::
data: {"response": {"candidates": [...], "usageMetadata": {...}}, "traceId": "..."}
"""
accumulated = ""
input_tokens = 0
output_tokens = 0
finish_reason = ""
for raw_line in response:
line: str = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")
if not line.startswith("data:"):
continue
data_str = line[5:].strip()
if not data_str or data_str == "[DONE]":
continue
try:
data: dict[str, Any] = json.loads(data_str)
except json.JSONDecodeError:
continue
# The outer envelope is {"response": {...}, "traceId": "..."}.
payload: dict[str, Any] = data.get("response", data)
usage = payload.get("usageMetadata", {})
if usage:
input_tokens = usage.get("promptTokenCount", input_tokens)
output_tokens = usage.get("candidatesTokenCount", output_tokens)
for candidate in payload.get("candidates", []):
fr = candidate.get("finishReason", "")
if fr:
finish_reason = fr
for part in candidate.get("content", {}).get("parts", []):
if "text" in part and not part.get("thought"):
delta: str = part["text"]
accumulated += delta
yield TextDeltaEvent(content=delta, snapshot=accumulated)
elif "functionCall" in part:
fc: dict[str, Any] = part["functionCall"]
tool_use_id = fc.get("id") or str(uuid.uuid4())
thought_sig = part.get("thoughtSignature", "") # sibling of functionCall
if thought_sig and on_thought_signature:
on_thought_signature(tool_use_id, thought_sig)
args = fc.get("args", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
yield ToolCallEvent(
tool_use_id=tool_use_id,
tool_name=fc.get("name", ""),
tool_input=args,
)
if accumulated:
yield TextEndEvent(full_text=accumulated)
yield FinishEvent(
stop_reason=_map_finish_reason(finish_reason),
input_tokens=input_tokens,
output_tokens=output_tokens,
model=model,
)
# ---------------------------------------------------------------------------
# Provider
# ---------------------------------------------------------------------------
class AntigravityProvider(LLMProvider):
"""LLM provider for Google's internal Antigravity Code Assist gateway.
No local proxy required. Handles OAuth token refresh, Gemini-format
request/response conversion, and SSE streaming directly.
"""
def __init__(self, model: str = "gemini-3-flash") -> None:
# Strip any provider prefix ("openai/gemini-3-flash" → "gemini-3-flash").
if "/" in model:
model = model.split("/", 1)[1]
self.model = model
self._access_token: str | None = None
self._refresh_token: str | None = None
self._project_id: str = _DEFAULT_PROJECT_ID
self._token_expires_at: float = 0.0
self._thought_sigs: dict[str, str] = {} # tool_use_id → thoughtSignature
self._init_credentials()
# --- Credential management -------------------------------------------- #
def _init_credentials(self) -> None:
"""Load credentials from the best available source."""
access, refresh, project_id, expires_at = _load_from_json_file()
if refresh:
self._refresh_token = refresh
self._project_id = project_id
self._access_token = access
self._token_expires_at = expires_at
return
# Fall back to IDE state DB.
access, refresh, expires_at = _load_from_ide_db()
if access:
self._access_token = access
self._refresh_token = refresh
self._token_expires_at = expires_at
def has_credentials(self) -> bool:
"""Return True if any credential is available."""
return bool(self._access_token or self._refresh_token)
def _ensure_token(self) -> str:
"""Return a valid access token, refreshing via OAuth if needed."""
if (
self._access_token
and self._token_expires_at
and time.time() < self._token_expires_at - _TOKEN_REFRESH_BUFFER_SECS
):
return self._access_token
if self._refresh_token:
result = _do_token_refresh(self._refresh_token)
if result:
self._access_token, self._token_expires_at = result
return self._access_token
if self._access_token:
logger.warning("Using potentially stale Antigravity access token")
return self._access_token
raise RuntimeError(
"No valid Antigravity credentials. "
"Run: uv run python core/antigravity_auth.py auth account add"
)
# --- Request building -------------------------------------------------- #
def _build_body(
self,
messages: list[dict[str, Any]],
system: str,
tools: list[Tool] | None,
max_tokens: int,
) -> dict[str, Any]:
contents = _to_gemini_contents(messages, self._thought_sigs)
inner: dict[str, Any] = {
"contents": contents,
"generationConfig": {"maxOutputTokens": max_tokens},
}
if system:
inner["systemInstruction"] = {"parts": [{"text": system}]}
if tools:
inner["tools"] = [
{
"functionDeclarations": [
{
"name": _clean_tool_name(t.name),
"description": t.description,
"parameters": t.parameters
or {
"type": "object",
"properties": {},
},
}
for t in tools
]
}
]
return {
"project": self._project_id,
"model": self.model,
"request": inner,
"requestType": "agent",
"userAgent": "antigravity",
"requestId": f"agent-{uuid.uuid4()}",
}
# --- HTTP transport ---------------------------------------------------- #
def _post(self, body: dict[str, Any], *, streaming: bool) -> Any:
"""POST to the Antigravity endpoint, falling back through the endpoint list."""
import urllib.error # noqa: PLC0415
import urllib.request # noqa: PLC0415
token = self._ensure_token()
body_bytes = json.dumps(body).encode("utf-8")
path = (
"/v1internal:streamGenerateContent?alt=sse"
if streaming
else "/v1internal:generateContent"
)
headers = {
**_BASE_HEADERS,
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
if streaming:
headers["Accept"] = "text/event-stream"
last_exc: Exception | None = None
for base_url in _ENDPOINTS:
url = f"{base_url}{path}"
req = urllib.request.Request(url, data=body_bytes, headers=headers, method="POST")
try:
return urllib.request.urlopen(req, timeout=120) # noqa: S310
except urllib.error.HTTPError as exc:
if exc.code in (401, 403) and self._refresh_token:
# Token rejected — refresh once and retry this endpoint.
result = _do_token_refresh(self._refresh_token)
if result:
self._access_token, self._token_expires_at = result
headers["Authorization"] = f"Bearer {self._access_token}"
req2 = urllib.request.Request(
url, data=body_bytes, headers=headers, method="POST"
)
try:
return urllib.request.urlopen(req2, timeout=120) # noqa: S310
except urllib.error.HTTPError as exc2:
last_exc = exc2
continue
last_exc = exc
continue
elif exc.code >= 500:
last_exc = exc
continue
# Include the API response body in the exception for easier debugging.
try:
err_body = exc.read().decode("utf-8", errors="replace")
except Exception:
err_body = "(unreadable)"
raise RuntimeError(f"Antigravity HTTP {exc.code} from {url}: {err_body}") from exc
except (urllib.error.URLError, OSError) as exc:
last_exc = exc
continue
raise RuntimeError(
f"All Antigravity endpoints failed. Last error: {last_exc}"
) from last_exc
# --- LLMProvider interface --------------------------------------------- #
def complete(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 1024,
response_format: dict[str, Any] | None = None,
json_mode: bool = False,
max_retries: int | None = None,
) -> LLMResponse:
if json_mode:
suffix = "\n\nPlease respond with a valid JSON object."
system = (system + suffix) if system else suffix.strip()
body = self._build_body(messages, system, tools, max_tokens)
resp = self._post(body, streaming=False)
return _parse_complete_response(json.loads(resp.read()), self.model)
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
import asyncio # noqa: PLC0415
import concurrent.futures # noqa: PLC0415
loop = asyncio.get_running_loop()
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
def _blocking_work() -> None:
try:
body = self._build_body(messages, system, tools, max_tokens)
http_resp = self._post(body, streaming=True)
for event in _parse_sse_stream(
http_resp, self.model, self._thought_sigs.__setitem__
):
loop.call_soon_threadsafe(queue.put_nowait, event)
except Exception as exc:
logger.error("Antigravity stream error: %s", exc)
loop.call_soon_threadsafe(queue.put_nowait, StreamErrorEvent(error=str(exc)))
finally:
loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
fut = loop.run_in_executor(executor, _blocking_work)
try:
while True:
event = await queue.get()
if event is None:
break
yield event
finally:
await fut
executor.shutdown(wait=False)
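A hedged usage sketch for the new provider; the import path is assumed from context, and credentials must already exist in ~/.hive/antigravity-accounts.json or the Antigravity IDE state DB.

from framework.llm.antigravity import AntigravityProvider  # module path assumed

provider = AntigravityProvider(model="gemini-3-flash")
if provider.has_credentials():
    resp = provider.complete(
        messages=[{"role": "user", "content": "Reply with one word: ready?"}],
        system="Be terse.",
        max_tokens=32,
    )
    print(resp.content, resp.stop_reason)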
+106
@@ -0,0 +1,106 @@
"""Model capability checks for LLM providers.
Vision support rules are derived from official vendor documentation:
- ZAI (z.ai): docs.z.ai/guides/vlm (GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only)
- MiniMax: platform.minimax.io/docs (minimax-vl-01 is vision; M2.x are text-only)
- DeepSeek: api-docs.deepseek.com (deepseek-vl2 is vision; chat/reasoner are text-only)
- Cerebras: inference-docs.cerebras.ai (no vision models at all)
- Groq: console.groq.com/docs/vision (vision capable; treat as supported by default)
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
don't reliably indicate vision support, so users must configure explicitly
"""
from __future__ import annotations
def _model_name(model: str) -> str:
"""Return the bare model name after stripping any 'provider/' prefix."""
if "/" in model:
return model.split("/", 1)[1]
return model
# Step 1: explicit vision allow-list — these always support images regardless
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
# is allowed even though glm-4.6 is denied.
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
"glm-4v", # GLM-4V series (legacy)
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
# DeepSeek vision models
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
# MiniMax vision model
"minimax-vl", # minimax-vl-01
)
# Step 2: provider-level deny — every model from this provider is text-only.
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
# Cerebras: inference-docs.cerebras.ai lists only text models
"cerebras/",
# Local runners: model names don't reliably indicate vision support
"ollama/",
"ollama_chat/",
"lm_studio/",
"vllm/",
"llamacpp/",
)
# Step 3: per-model deny — text-only models within otherwise mixed providers.
# Matched against the bare model name (provider prefix stripped, lower-cased).
# The vision allow-list above is checked first, so vision variants of the same
# family are already handled before these deny patterns are reached.
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
# --- ZAI / GLM family ---
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
# vision: glm-4v, glm-4.6v (caught by allow-list above)
"glm-5",
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
"glm-4.7",
"glm-4.5",
"zai-glm",
# --- DeepSeek ---
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
# vision: deepseek-vl2 (caught by allow-list above)
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
# VL models are allowed through and rely on LiteLLM's native VL support.
"deepseek-chat",
"deepseek-coder",
"deepseek-reasoner",
# --- MiniMax ---
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
# vision: minimax-vl-01 (caught by allow-list above)
"minimax-m2",
"minimax-text",
"abab",
)
def supports_image_tool_results(model: str) -> bool:
"""Return whether *model* can receive image content in messages.
Used to gate both user-message images and tool-result image blocks.
Logic (checked in order):
1. Vision allow-list → True (known vision model, skip all denies)
2. Provider deny → False (entire provider is text-only)
3. Model deny → False (specific text-only model within a mixed provider)
4. Default → True (assume capable; unknown providers and models)
"""
model_lower = model.lower()
bare = _model_name(model_lower)
# 1. Explicit vision allow — takes priority over all denies
if any(bare.startswith(p) for p in _VISION_ALLOW_BARE_PREFIXES):
return True
# 2. Provider-level deny (all models from this provider are text-only)
if any(model_lower.startswith(p) for p in _TEXT_ONLY_PROVIDER_PREFIXES):
return False
# 3. Per-model deny (text-only variants within mixed-capability families)
if any(bare.startswith(p) for p in _TEXT_ONLY_MODEL_BARE_PREFIXES):
return False
# 4. Default: assume vision capable
# Covers: OpenAI, Anthropic, Google, Mistral, Kimi, and other hosted providers
return True
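The precedence rules in action, assuming supports_image_tool_results is imported from this module; note how the allow-list beats the glm-4.6 family deny:

assert supports_image_tool_results("zai/glm-4.6v-flash") is True   # step 1: allow-list
assert supports_image_tool_results("zai/glm-4.6") is False         # step 3: family deny
assert supports_image_tool_results("ollama/llava") is False        # step 2: provider deny
assert supports_image_tool_results("openai/gpt-4o") is True        # step 4: default allow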
+649 -12
@@ -7,9 +7,13 @@ Groq, and local models.
See: https://docs.litellm.ai/docs/providers
"""
import ast
import asyncio
import hashlib
import json
import logging
import os
import re
import time
from collections.abc import AsyncIterator
from datetime import datetime
@@ -44,7 +48,10 @@ def _patch_litellm_anthropic_oauth() -> None:
"""
try:
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
from litellm.types.llms.anthropic import ANTHROPIC_OAUTH_TOKEN_PREFIX
from litellm.types.llms.anthropic import (
ANTHROPIC_OAUTH_BETA_HEADER,
ANTHROPIC_OAUTH_TOKEN_PREFIX,
)
except ImportError:
logger.warning(
"Could not apply litellm Anthropic OAuth patch — litellm internals may have "
@@ -69,9 +76,27 @@ def _patch_litellm_anthropic_oauth() -> None:
api_key=api_key,
api_base=api_base,
)
# Check both authorization header and x-api-key for OAuth tokens.
# litellm's optionally_handle_anthropic_oauth only checks headers["authorization"],
# but hive passes OAuth tokens via api_key — so litellm puts them into x-api-key.
# Anthropic rejects OAuth tokens in x-api-key; they must go in Authorization: Bearer.
auth = result.get("authorization", "")
if auth.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
x_api_key = result.get("x-api-key", "")
oauth_prefix = f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"
auth_is_oauth = auth.startswith(oauth_prefix)
key_is_oauth = x_api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
if auth_is_oauth or key_is_oauth:
token = x_api_key if key_is_oauth else auth.removeprefix("Bearer ").strip()
result.pop("x-api-key", None)
result["authorization"] = f"Bearer {token}"
# Merge the OAuth beta header with any existing beta headers.
existing_beta = result.get("anthropic-beta", "")
beta_parts = (
[b.strip() for b in existing_beta.split(",") if b.strip()] if existing_beta else []
)
if ANTHROPIC_OAUTH_BETA_HEADER not in beta_parts:
beta_parts.append(ANTHROPIC_OAUTH_BETA_HEADER)
result["anthropic-beta"] = ",".join(beta_parts)
return result
AnthropicModelInfo.validate_environment = _patched_validate_environment
@@ -130,11 +155,15 @@ def _patch_litellm_metadata_nonetype() -> None:
if litellm is not None:
_patch_litellm_anthropic_oauth()
_patch_litellm_metadata_nonetype()
# Let litellm silently drop params unsupported by the target provider
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
litellm.drop_params = True
RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
MINIMAX_API_BASE = "https://api.minimax.io/v1"
OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"
# Providers that accept cache_control on message content blocks.
# Anthropic: native ephemeral caching. MiniMax & Z-AI/GLM: pass-through to their APIs.
@@ -159,10 +188,69 @@ def _model_supports_cache_control(model: str) -> bool:
# enforces a coding-agent whitelist that blocks unknown User-Agents.
KIMI_API_BASE = "https://api.kimi.com/coding"
# Claude Code OAuth subscription: the Anthropic API requires a specific
# User-Agent and a billing integrity header for OAuth-authenticated requests.
CLAUDE_CODE_VERSION = "2.1.76"
CLAUDE_CODE_USER_AGENT = f"claude-code/{CLAUDE_CODE_VERSION}"
_CLAUDE_CODE_BILLING_SALT = "59cf53e54c78"
def _sample_js_code_unit(text: str, idx: int) -> str:
"""Return the character at UTF-16 code unit index *idx*, matching JS semantics."""
encoded = text.encode("utf-16-le")
unit_offset = idx * 2
if unit_offset + 2 > len(encoded):
return "0"
code_unit = int.from_bytes(encoded[unit_offset : unit_offset + 2], "little")
return chr(code_unit)
def _claude_code_billing_header(messages: list[dict[str, Any]]) -> str:
"""Build the billing integrity system block required by Anthropic's OAuth path."""
# Find the first user message text
first_text = ""
for msg in messages:
if msg.get("role") != "user":
continue
content = msg.get("content")
if isinstance(content, str):
first_text = content
break
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text" and block.get("text"):
first_text = block["text"]
break
if first_text:
break
sampled = "".join(_sample_js_code_unit(first_text, i) for i in (4, 7, 20))
version_hash = hashlib.sha256(
f"{_CLAUDE_CODE_BILLING_SALT}{sampled}{CLAUDE_CODE_VERSION}".encode()
).hexdigest()
entrypoint = os.environ.get("CLAUDE_CODE_ENTRYPOINT", "").strip() or "cli"
return (
f"x-anthropic-billing-header: cc_version={CLAUDE_CODE_VERSION}.{version_hash[:3]}; "
f"cc_entrypoint={entrypoint}; cch=00000;"
)
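A worked example of the code-unit sampling above, which indexes strings the way JavaScript does (UTF-16 code units, not Unicode code points); the billing header samples indices 4, 7, and 20 of the first user message this way:

assert _sample_js_code_unit("hello", 1) == "e"
assert _sample_js_code_unit("hi", 9) == "0"    # out of range falls back to "0"
# Astral characters occupy two code units, exactly as in JS:
assert _sample_js_code_unit("\U0001F600a", 2) == "a"  # indices 0-1 are the surrogate pair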
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
# Conversation-structure issues are deterministic — long waits don't help.
EMPTY_STREAM_MAX_RETRIES = 3
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS = (
"no endpoints found that support tool use",
"no endpoints available that support tool use",
"provider routing",
)
OPENROUTER_TOOL_CALL_RE = re.compile(
r"<\|tool_call_start\|>\s*(.*?)\s*<\|tool_call_end\|>",
re.DOTALL,
)
OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS = 3600
# OpenRouter routing can change over time, so tool-compat caching must expire.
OPENROUTER_TOOL_COMPAT_MODEL_CACHE: dict[str, float] = {}
# Directory for dumping failed requests
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
@@ -205,6 +293,24 @@ def _prune_failed_request_dumps(max_files: int = MAX_FAILED_REQUEST_DUMPS) -> No
pass # Best-effort — never block the caller
def _remember_openrouter_tool_compat_model(model: str) -> None:
"""Cache OpenRouter tool-compat fallback for a bounded time window."""
OPENROUTER_TOOL_COMPAT_MODEL_CACHE[model] = (
time.monotonic() + OPENROUTER_TOOL_COMPAT_CACHE_TTL_SECONDS
)
def _is_openrouter_tool_compat_cached(model: str) -> bool:
"""Return True when the cached OpenRouter compat entry is still fresh."""
expires_at = OPENROUTER_TOOL_COMPAT_MODEL_CACHE.get(model)
if expires_at is None:
return False
if expires_at <= time.monotonic():
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.pop(model, None)
return False
return True
def _dump_failed_request(
model: str,
kwargs: dict[str, Any],
@@ -408,11 +514,19 @@ class LiteLLMProvider(LLMProvider):
self.api_key = api_key
self.api_base = api_base or self._default_api_base_for_model(_original_model)
self.extra_kwargs = kwargs
# Detect Claude Code OAuth subscription by checking the api_key prefix.
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
if self._claude_code_oauth:
# Anthropic requires a specific User-Agent for OAuth requests.
eh = self.extra_kwargs.setdefault("extra_headers", {})
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
# several standard OpenAI params: max_output_tokens, stream_options.
self._codex_backend = bool(
self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
)
# Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
if litellm is None:
raise ImportError(
@@ -431,6 +545,8 @@ class LiteLLMProvider(LLMProvider):
model_lower = model.lower()
if model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
return MINIMAX_API_BASE
if model_lower.startswith("openrouter/"):
return OPENROUTER_API_BASE
if model_lower.startswith("kimi/"):
return KIMI_API_BASE
if model_lower.startswith("hive/"):
@@ -773,6 +889,9 @@ class LiteLLMProvider(LLMProvider):
return await self._collect_stream_to_response(stream_iter)
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
@@ -834,11 +953,504 @@ class LiteLLMProvider(LLMProvider):
},
}
def _is_anthropic_model(self) -> bool:
"""Return True when the configured model targets Anthropic."""
model = (self.model or "").lower()
return model.startswith("anthropic/") or model.startswith("claude-")
def _is_minimax_model(self) -> bool:
"""Return True when the configured model targets MiniMax."""
model = (self.model or "").lower()
return model.startswith("minimax/") or model.startswith("minimax-")
def _is_openrouter_model(self) -> bool:
"""Return True when the configured model targets OpenRouter."""
model = (self.model or "").lower()
if model.startswith("openrouter/"):
return True
api_base = (self.api_base or "").lower()
return "openrouter.ai/api/v1" in api_base
def _should_use_openrouter_tool_compat(
self,
error: BaseException,
tools: list[Tool] | None,
) -> bool:
"""Return True when OpenRouter rejects native tool use for the model."""
if not tools or not self._is_openrouter_model():
return False
error_text = str(error).lower()
return "openrouter" in error_text and any(
snippet in error_text for snippet in OPENROUTER_TOOL_COMPAT_ERROR_SNIPPETS
)
@staticmethod
def _extract_json_object(text: str) -> dict[str, Any] | None:
"""Extract the first JSON object from a model response."""
candidates = [text.strip()]
stripped = text.strip()
if stripped.startswith("```"):
fence_lines = stripped.splitlines()
if len(fence_lines) >= 3:
candidates.append("\n".join(fence_lines[1:-1]).strip())
decoder = json.JSONDecoder()
for candidate in candidates:
if not candidate:
continue
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
parsed = None
if isinstance(parsed, dict):
return parsed
for start_idx, char in enumerate(candidate):
if char != "{":
continue
try:
parsed, _ = decoder.raw_decode(candidate[start_idx:])
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
def _parse_openrouter_tool_compat_response(
self,
content: str,
tools: list[Tool],
) -> tuple[str, list[dict[str, Any]]]:
"""Parse JSON tool-compat output into assistant text and tool calls."""
payload = self._extract_json_object(content)
if payload is None:
text_tool_content, text_tool_calls = self._parse_openrouter_text_tool_calls(
content,
tools,
)
if text_tool_calls:
logger.info(
"[openrouter-tool-compat] Parsed textual tool-call markers for %s",
self.model,
)
return text_tool_content, text_tool_calls
logger.info(
"[openrouter-tool-compat] %s returned non-JSON fallback content; "
"treating it as plain text.",
self.model,
)
return content.strip(), []
assistant_text = payload.get("assistant_response")
if not isinstance(assistant_text, str):
assistant_text = payload.get("content")
if not isinstance(assistant_text, str):
assistant_text = payload.get("response")
if not isinstance(assistant_text, str):
assistant_text = ""
tool_calls_raw = payload.get("tool_calls")
if not tool_calls_raw and {"name", "arguments"} <= payload.keys():
tool_calls_raw = [payload]
elif isinstance(payload.get("tool_call"), dict):
tool_calls_raw = [payload["tool_call"]]
if not isinstance(tool_calls_raw, list):
tool_calls_raw = []
allowed_tool_names = {tool.name for tool in tools}
tool_calls: list[dict[str, Any]] = []
compat_prefix = f"openrouter_compat_{time.time_ns()}"
for idx, raw_call in enumerate(tool_calls_raw):
if not isinstance(raw_call, dict):
continue
function_block = raw_call.get("function")
function_name = (
raw_call.get("name")
or raw_call.get("tool_name")
or (function_block.get("name") if isinstance(function_block, dict) else None)
)
if not isinstance(function_name, str) or function_name not in allowed_tool_names:
if function_name:
logger.warning(
"[openrouter-tool-compat] Ignoring unknown tool '%s' for model %s",
function_name,
self.model,
)
continue
arguments = raw_call.get("arguments")
if arguments is None:
arguments = raw_call.get("tool_input")
if arguments is None:
arguments = raw_call.get("input")
if arguments is None and isinstance(function_block, dict):
arguments = function_block.get("arguments")
if arguments is None:
arguments = {}
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {"_raw": arguments}
elif not isinstance(arguments, dict):
arguments = {"value": arguments}
tool_calls.append(
{
"id": f"{compat_prefix}_{idx}",
"name": function_name,
"input": arguments,
}
)
return assistant_text.strip(), tool_calls
@staticmethod
def _close_truncated_json_fragment(fragment: str) -> str:
"""Close a truncated JSON fragment by balancing quotes/brackets."""
stack: list[str] = []
in_string = False
escaped = False
normalized = fragment.rstrip()
while normalized and normalized[-1] in ",:{[":
normalized = normalized[:-1].rstrip()
for char in normalized:
if in_string:
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
in_string = False
continue
if char == '"':
in_string = True
elif char in "{[":
stack.append(char)
elif char == "}" and stack and stack[-1] == "{":
stack.pop()
elif char == "]" and stack and stack[-1] == "[":
stack.pop()
if in_string:
if escaped:
normalized = normalized[:-1]
normalized += '"'
for opener in reversed(stack):
normalized += "}" if opener == "{" else "]"
return normalized
def _repair_truncated_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
"""Try to recover a truncated JSON object from tool-call arguments."""
stripped = raw_arguments.strip()
if not stripped or stripped[0] != "{":
return None
max_trim = min(len(stripped), 256)
for trim in range(max_trim + 1):
candidate = stripped[: len(stripped) - trim].rstrip()
if not candidate:
break
candidate = self._close_truncated_json_fragment(candidate)
try:
parsed = json.loads(candidate)
except json.JSONDecodeError:
continue
if isinstance(parsed, dict):
return parsed
return None
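How the repair behaves on a typical mid-string truncation (an illustrative trace, not an exhaustive spec):

raw = '{"path": "/tmp/out.txt", "content": "partial te'  # stream cut mid-string
# _close_truncated_json_fragment() balances the open quote and brace:
#   '{"path": "/tmp/out.txt", "content": "partial te"}'
# so _repair_truncated_tool_arguments(raw) returns
#   {"path": "/tmp/out.txt", "content": "partial te"}
# where the previous code would have surfaced {"_raw": raw} instead.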
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
"""Parse streamed tool arguments, repairing truncation when possible."""
try:
parsed = json.loads(raw_arguments) if raw_arguments else {}
except json.JSONDecodeError:
parsed = None
if isinstance(parsed, dict):
return parsed
repaired = self._repair_truncated_tool_arguments(raw_arguments)
if repaired is not None:
logger.warning(
"[tool-args] Recovered truncated arguments for %s on %s",
tool_name,
self.model,
)
return repaired
raise ValueError(
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
)
def _parse_openrouter_text_tool_calls(
self,
content: str,
tools: list[Tool],
) -> tuple[str, list[dict[str, Any]]]:
"""Parse textual OpenRouter tool calls into synthetic tool calls.
Supports both:
- Marker wrapped payloads: <|tool_call_start|>...<|tool_call_end|>
- Plain one-line tool calls: ask_user("...", ["..."])
"""
tools_by_name = {tool.name: tool for tool in tools}
compat_prefix = f"openrouter_compat_{time.time_ns()}"
tool_calls: list[dict[str, Any]] = []
segment_index = 0
for match in OPENROUTER_TOOL_CALL_RE.finditer(content):
parsed_calls = self._parse_openrouter_text_tool_call_block(
block=match.group(1),
tools_by_name=tools_by_name,
compat_prefix=f"{compat_prefix}_{segment_index}",
)
if parsed_calls:
segment_index += 1
tool_calls.extend(parsed_calls)
stripped_content = OPENROUTER_TOOL_CALL_RE.sub("", content)
retained_lines: list[str] = []
for line in stripped_content.splitlines():
stripped_line = line.strip()
if not stripped_line:
retained_lines.append(line)
continue
candidate = stripped_line
if candidate.startswith("`") and candidate.endswith("`") and len(candidate) > 1:
candidate = candidate[1:-1].strip()
parsed_calls = self._parse_openrouter_text_tool_call_block(
block=candidate,
tools_by_name=tools_by_name,
compat_prefix=f"{compat_prefix}_{segment_index}",
)
if parsed_calls:
segment_index += 1
tool_calls.extend(parsed_calls)
continue
retained_lines.append(line)
stripped_text = "\n".join(retained_lines).strip()
return stripped_text, tool_calls
def _parse_openrouter_text_tool_call_block(
self,
block: str,
tools_by_name: dict[str, Tool],
compat_prefix: str,
) -> list[dict[str, Any]]:
"""Parse a single textual tool-call block like [tool(arg='x')]."""
try:
parsed = ast.parse(block.strip(), mode="eval").body
except SyntaxError:
return []
call_nodes = parsed.elts if isinstance(parsed, ast.List) else [parsed]
tool_calls: list[dict[str, Any]] = []
for call_index, call_node in enumerate(call_nodes):
if not isinstance(call_node, ast.Call) or not isinstance(call_node.func, ast.Name):
continue
tool_name = call_node.func.id
tool = tools_by_name.get(tool_name)
if tool is None:
continue
try:
tool_input = self._parse_openrouter_text_tool_call_arguments(
call_node=call_node,
tool=tool,
)
except (ValueError, SyntaxError):
continue
tool_calls.append(
{
"id": f"{compat_prefix}_{call_index}",
"name": tool_name,
"input": tool_input,
}
)
return tool_calls
@staticmethod
def _parse_openrouter_text_tool_call_arguments(
call_node: ast.Call,
tool: Tool,
) -> dict[str, Any]:
"""Parse positional/keyword args from a textual tool call."""
properties = tool.parameters.get("properties", {})
positional_keys = list(properties.keys())
tool_input: dict[str, Any] = {}
if len(call_node.args) > len(positional_keys):
raise ValueError("Too many positional args for textual tool call")
for idx, arg_node in enumerate(call_node.args):
tool_input[positional_keys[idx]] = ast.literal_eval(arg_node)
for kwarg in call_node.keywords:
if kwarg.arg is None:
raise ValueError("Star args are not supported in textual tool calls")
tool_input[kwarg.arg] = ast.literal_eval(kwarg.value)
return tool_input
def _build_openrouter_tool_compat_messages(
self,
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
) -> list[dict[str, Any]]:
"""Build a JSON-only prompt for models without native tool support."""
tool_specs = [
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
for tool in tools
]
compat_instruction = (
"Tool compatibility mode is active because this OpenRouter model does not support "
"native function calling on the routed provider.\n"
"Return exactly one JSON object and nothing else.\n"
'Schema: {"assistant_response": string, '
'"tool_calls": [{"name": string, "arguments": object}]}\n'
"Rules:\n"
"- If a tool is required, put one or more entries in tool_calls "
"and do not invent tool results.\n"
"- If no tool is required, set tool_calls to [] and put the full "
"answer in assistant_response.\n"
"- Only use tool names from the allowed tool list.\n"
"- arguments must always be valid JSON objects.\n"
f"Allowed tools:\n{json.dumps(tool_specs, ensure_ascii=True)}"
)
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
full_messages.extend(messages)
return [
message
for message in full_messages
if not (
message.get("role") == "assistant"
and not message.get("content")
and not message.get("tool_calls")
)
]
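For reference, a reply that satisfies the compat schema above and what the parser makes of it (the tool name is assumed to be on the allowed list):

compat_reply = {
    "assistant_response": "Checking the weather now.",
    "tool_calls": [{"name": "get_weather", "arguments": {"city": "Berlin"}}],
}
# _parse_openrouter_tool_compat_response() yields the assistant text plus one
# synthetic tool call with an id of the form "openrouter_compat_<ns>_0".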
async def _acomplete_via_openrouter_tool_compat(
self,
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
max_tokens: int,
) -> LLMResponse:
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
"max_tokens": max_tokens,
**self.extra_kwargs,
}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
response = await self._acompletion_with_rate_limit_retry(**kwargs)
raw_content = response.choices[0].message.content or ""
assistant_text, tool_calls = self._parse_openrouter_tool_compat_response(
raw_content,
tools,
)
usage = response.usage
input_tokens = usage.prompt_tokens if usage else 0
output_tokens = usage.completion_tokens if usage else 0
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
return LLMResponse(
content=assistant_text,
model=response.model or self.model,
input_tokens=input_tokens,
output_tokens=output_tokens,
stop_reason=stop_reason,
raw_response={
"compat_mode": "openrouter_tool_emulation",
"tool_calls": tool_calls,
"response": response,
},
)
async def _stream_via_openrouter_tool_compat(
self,
messages: list[dict[str, Any]],
system: str,
tools: list[Tool],
max_tokens: int,
) -> AsyncIterator[StreamEvent]:
"""Fallback stream for OpenRouter models without native tool support."""
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
logger.info(
"[openrouter-tool-compat] Using compatibility mode for %s",
self.model,
)
try:
response = await self._acomplete_via_openrouter_tool_compat(
messages=messages,
system=system,
tools=tools,
max_tokens=max_tokens,
)
except Exception as e:
yield StreamErrorEvent(error=str(e), recoverable=False)
return
raw_response = response.raw_response if isinstance(response.raw_response, dict) else {}
tool_calls = raw_response.get("tool_calls", [])
if response.content:
yield TextDeltaEvent(content=response.content, snapshot=response.content)
yield TextEndEvent(full_text=response.content)
for tool_call in tool_calls:
yield ToolCallEvent(
tool_use_id=tool_call["id"],
tool_name=tool_call["name"],
tool_input=tool_call["input"],
)
yield FinishEvent(
stop_reason=response.stop_reason,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
model=response.model,
)
async def _stream_via_nonstream_completion(
self,
messages: list[dict[str, Any]],
@@ -882,12 +1494,11 @@ class LiteLLMProvider(LLMProvider):
tool_calls = msg.tool_calls or []
for tc in tool_calls:
parsed_args: Any
args = tc.function.arguments if tc.function else ""
try:
parsed_args = json.loads(args) if args else {}
except json.JSONDecodeError:
parsed_args = {"_raw": args}
parsed_args = self._parse_tool_call_arguments(
args,
tc.function.name if tc.function else "",
)
yield ToolCallEvent(
tool_use_id=getattr(tc, "id", ""),
tool_name=tc.function.name if tc.function else "",
@@ -946,7 +1557,20 @@ class LiteLLMProvider(LLMProvider):
yield event
return
if tools and self._is_openrouter_model() and _is_openrouter_tool_compat_cached(self.model):
async for event in self._stream_via_openrouter_tool_compat(
messages=messages,
system=system,
tools=tools,
max_tokens=max_tokens,
):
yield event
return
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
@@ -984,9 +1608,12 @@ class LiteLLMProvider(LLMProvider):
"messages": full_messages,
"max_tokens": max_tokens,
"stream": True,
"stream_options": {"include_usage": True},
**self.extra_kwargs,
}
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
# Only include it for providers that support it.
if not self._is_anthropic_model():
kwargs["stream_options"] = {"include_usage": True}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
@@ -1092,10 +1719,10 @@ class LiteLLMProvider(LLMProvider):
if choice.finish_reason:
stream_finish_reason = choice.finish_reason
for _idx, tc_data in sorted(tool_calls_acc.items()):
try:
parsed_args = json.loads(tc_data["arguments"])
except (json.JSONDecodeError, KeyError):
parsed_args = {"_raw": tc_data.get("arguments", "")}
parsed_args = self._parse_tool_call_arguments(
tc_data.get("arguments", ""),
tc_data.get("name", ""),
)
tail_events.append(
ToolCallEvent(
tool_use_id=tc_data["id"],
@@ -1276,6 +1903,16 @@ class LiteLLMProvider(LLMProvider):
return
except Exception as e:
if self._should_use_openrouter_tool_compat(e, tools):
_remember_openrouter_tool_compat_model(self.model)
async for event in self._stream_via_openrouter_tool_compat(
messages=messages,
system=system,
tools=tools or [],
max_tokens=max_tokens,
):
yield event
return
if _is_stream_transient_error(e) and attempt < RATE_LIMIT_MAX_RETRIES:
wait = _compute_retry_delay(attempt, exception=e)
logger.warning(
+2
@@ -45,6 +45,8 @@ class ToolResult:
tool_use_id: str
content: str
is_error: bool = False
image_content: list[dict[str, Any]] | None = None
is_skill_content: bool = False # AS-10: marks activated skill body, protected from pruning
class LLMProvider(ABC):
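A hedged construction sketch for the extended dataclass; the image block schema below is an assumption (the field is only typed as a list of dicts), so treat the exact keys as illustrative:

result = ToolResult(
    tool_use_id="toolu_123",
    content="Screenshot captured.",
    image_content=[
        # Hypothetical OpenAI-style image part; the dataclass does not mandate this shape.
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
    ],
)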
+6 -1
@@ -208,7 +208,12 @@ def configure_logging(
# Suppress noisy LiteLLM INFO logs (model/provider line + Provider List URL
# printed on every single completion call). Warnings and errors still show.
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
# Honour LITELLM_LOG env var so users can opt-in to debug output.
_litellm_level = os.getenv("LITELLM_LOG", "").upper()
if _litellm_level and hasattr(logging, _litellm_level):
logging.getLogger("LiteLLM").setLevel(getattr(logging, _litellm_level))
else:
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
# When in JSON mode, configure known third-party loggers to use JSON formatter
# This ensures libraries like LiteLLM, httpcore also output clean JSON
+168 -20
@@ -1,7 +1,7 @@
"""MCP Client for connecting to Model Context Protocol servers.
This module provides a client for connecting to MCP servers and invoking their tools.
Supports both STDIO and HTTP transports using the official MCP Python SDK.
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP Python SDK.
"""
import asyncio
@@ -22,7 +22,7 @@ class MCPServerConfig:
"""Configuration for an MCP server connection."""
name: str
transport: Literal["stdio", "http"]
transport: Literal["stdio", "http", "unix", "sse"]
# For STDIO transport
command: str | None = None
@@ -33,6 +33,7 @@ class MCPServerConfig:
# For HTTP transport
url: str | None = None
headers: dict[str, str] = field(default_factory=dict)
socket_path: str | None = None
# Optional metadata
description: str = ""
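A hypothetical UNIX-socket server entry using the new fields (all values are placeholders; url stays required because requests are still HTTP-shaped, just carried over the socket):

cfg = MCPServerConfig(
    name="local-tools",
    transport="unix",
    url="http://localhost",             # HTTP base URL layered over the socket
    socket_path="/tmp/mcp-tools.sock",  # UNIX domain socket path
)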
@@ -52,7 +53,7 @@ class MCPClient:
"""
Client for communicating with MCP servers.
Supports both STDIO and HTTP transports using the official MCP SDK.
Supports STDIO, HTTP, UNIX socket, and SSE transports using the official MCP SDK.
Manages the connection lifecycle and provides methods to list and invoke tools.
"""
@@ -68,6 +69,7 @@ class MCPClient:
self._read_stream = None
self._write_stream = None
self._stdio_context = None # Context manager for stdio_client
self._sse_context = None # Context manager for sse_client
self._errlog_handle = None # Track errlog file handle for cleanup
self._http_client: httpx.Client | None = None
self._tools: dict[str, MCPTool] = {}
@@ -141,6 +143,10 @@ class MCPClient:
self._connect_stdio()
elif self.config.transport == "http":
self._connect_http()
elif self.config.transport == "unix":
self._connect_unix()
elif self.config.transport == "sse":
self._connect_sse()
else:
raise ValueError(f"Unsupported transport: {self.config.transport}")
@@ -266,10 +272,94 @@ class MCPClient:
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
# Continue anyway; the server might not have a health endpoint
def _connect_unix(self) -> None:
"""Connect to MCP server via UNIX domain socket transport."""
if not self.config.url:
raise ValueError("url is required for UNIX transport")
if not self.config.socket_path:
raise ValueError("socket_path is required for UNIX transport")
self._http_client = httpx.Client(
base_url=self.config.url,
headers=self.config.headers,
timeout=30.0,
transport=httpx.HTTPTransport(uds=self.config.socket_path),
)
try:
response = self._http_client.get("/health")
response.raise_for_status()
logger.info(
"Connected to MCP server '%s' via UNIX socket at %s",
self.config.name,
self.config.socket_path,
)
except Exception as e:
logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}")
# Continue anyway; the server might not have a health endpoint
def _connect_sse(self) -> None:
"""Connect to MCP server via SSE transport using MCP SDK with persistent session."""
if not self.config.url:
raise ValueError("url is required for SSE transport")
try:
loop_started = threading.Event()
connection_ready = threading.Event()
connection_error = []
def run_event_loop():
"""Run event loop in background thread."""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
loop_started.set()
async def init_connection():
try:
from mcp import ClientSession
from mcp.client.sse import sse_client
self._sse_context = sse_client(
self.config.url,
headers=self.config.headers,
timeout=30.0,
)
(
self._read_stream,
self._write_stream,
) = await self._sse_context.__aenter__()
self._session = ClientSession(self._read_stream, self._write_stream)
await self._session.__aenter__()
await self._session.initialize()
connection_ready.set()
except Exception as e:
connection_error.append(e)
connection_ready.set()
self._loop.create_task(init_connection())
self._loop.run_forever()
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
self._loop_thread.start()
loop_started.wait(timeout=5)
if not loop_started.is_set():
raise RuntimeError("Event loop failed to start")
connection_ready.wait(timeout=10)
if connection_error:
raise connection_error[0]
logger.info(f"Connected to MCP server '{self.config.name}' via SSE")
except Exception as e:
raise RuntimeError(f"Failed to connect to MCP server: {e}") from e
def _discover_tools(self) -> None:
"""Discover available tools from the MCP server."""
try:
if self.config.transport == "stdio":
if self.config.transport in {"stdio", "sse"}:
tools_list = self._run_async(self._list_tools_stdio_async())
else:
tools_list = self._list_tools_http()
@@ -371,9 +461,37 @@ class MCPClient:
if self.config.transport == "stdio":
with self._stdio_call_lock:
return self._run_async(self._call_tool_stdio_async(tool_name, arguments))
elif self.config.transport == "sse":
return self._call_tool_with_retry(
lambda: self._run_async(self._call_tool_stdio_async(tool_name, arguments))
)
elif self.config.transport == "unix":
return self._call_tool_with_retry(lambda: self._call_tool_http(tool_name, arguments))
else:
return self._call_tool_http(tool_name, arguments)
def _call_tool_with_retry(self, call: Any) -> Any:
"""Retry transient MCP transport failures once after reconnecting."""
if self.config.transport == "stdio":
return call()
if self.config.transport not in {"unix", "sse"}:
return call()
try:
return call()
except (httpx.ConnectError, httpx.ReadTimeout) as original_error:
logger.warning(
"Retrying MCP tool call after transport error from '%s': %s",
self.config.name,
original_error,
)
self._reconnect()
try:
return call()
except (httpx.ConnectError, httpx.ReadTimeout) as retry_error:
raise original_error from retry_error
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call tool via STDIO protocol using persistent session."""
if not self._session:
@@ -391,17 +509,30 @@ class MCPClient:
error_text = content_item.text
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
# Extract content
# Extract content — preserve image blocks alongside text
if result.content:
# MCP returns content as a list of content items
if len(result.content) > 0:
content_item = result.content[0]
# Check if it's a text content item
if hasattr(content_item, "text"):
return content_item.text
elif hasattr(content_item, "data"):
return content_item.data
return result.content
text_parts: list[str] = []
image_parts: list[dict[str, Any]] = []
for item in result.content:
if hasattr(item, "text"):
text_parts.append(item.text)
elif hasattr(item, "data") and hasattr(item, "mimeType"):
# MCP ImageContent — preserve as structured image block
image_parts.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{item.mimeType};base64,{item.data}",
},
}
)
elif hasattr(item, "data"):
text_parts.append(str(item.data))
text = "\n".join(text_parts) if text_parts else ""
if image_parts:
return {"_text": text, "_images": image_parts}
return text if text else None
return None
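So a mixed text-and-image tool result now comes back shaped roughly like this (illustrative, with the data URI truncated):

result = {
    "_text": "Here is the requested diagram",
    "_images": [
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
        },
    ],
}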
@@ -433,18 +564,24 @@ class MCPClient:
except Exception as e:
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
def _reconnect(self) -> None:
"""Reconnect to the configured MCP server."""
logger.info(f"Reconnecting to MCP server '{self.config.name}'...")
self.disconnect()
self.connect()
_CLEANUP_TIMEOUT = 10
_THREAD_JOIN_TIMEOUT = 12
async def _cleanup_stdio_async(self) -> None:
"""Async cleanup for STDIO session and context managers.
"""Async cleanup for persistent MCP session and context managers.
Cleanup order is critical:
- The session must be closed BEFORE the stdio_context because the session
depends on the streams provided by stdio_context.
- This mirrors the initialization order in _connect_stdio(), where
stdio_context is entered first (providing streams), then the session is
created with those streams and entered.
- The session must be closed BEFORE the transport context manager because the
session depends on the streams provided by that context.
- This mirrors the initialization order in _connect_stdio() / _connect_sse(),
where the transport context is entered first (providing streams), then the
session is created with those streams and entered.
- Do not change this ordering without carefully considering these dependencies.
"""
# First: close session (depends on stdio_context streams)
@@ -477,6 +614,16 @@ class MCPClient:
finally:
self._stdio_context = None
try:
if self._sse_context:
await self._sse_context.__aexit__(None, None, None)
except asyncio.CancelledError:
logger.debug("SSE context cleanup was cancelled; proceeding with best-effort shutdown")
except Exception as e:
logger.warning(f"Error closing SSE context: {e}")
finally:
self._sse_context = None
# Third: close errlog file handle if we opened one
if self._errlog_handle is not None:
try:
@@ -552,6 +699,7 @@ class MCPClient:
# Setting None to None is safe and ensures clean state.
self._session = None
self._stdio_context = None
self._sse_context = None
self._read_stream = None
self._write_stream = None
self._loop = None
@@ -0,0 +1,255 @@
"""Shared MCP client connection management."""
import logging
import threading
from typing import Any
import httpx
from framework.runner.mcp_client import MCPClient, MCPServerConfig
logger = logging.getLogger(__name__)
class MCPConnectionManager:
"""Process-wide MCP client pool keyed by server name."""
_instance = None
_lock = threading.Lock()
def __init__(self) -> None:
self._pool: dict[str, MCPClient] = {}
self._refcounts: dict[str, int] = {}
self._configs: dict[str, MCPServerConfig] = {}
self._pool_lock = threading.Lock()
# Transition events keep callers from racing a connect/reconnect/disconnect.
self._transitions: dict[str, threading.Event] = {}
@classmethod
def get_instance(cls) -> "MCPConnectionManager":
"""Return the process-level singleton instance."""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
@staticmethod
def _is_connected(client: MCPClient | None) -> bool:
return bool(client and getattr(client, "_connected", False))
def acquire(self, config: MCPServerConfig) -> MCPClient:
"""Get or create a shared connection and increment its refcount."""
server_name = config.name
while True:
should_connect = False
transition_event: threading.Event | None = None
with self._pool_lock:
client = self._pool.get(server_name)
if self._is_connected(client) and server_name not in self._transitions:
new_refcount = self._refcounts.get(server_name, 0) + 1
self._refcounts[server_name] = new_refcount
self._configs[server_name] = config
logger.debug(
"Reusing pooled connection for MCP server '%s' (refcount=%d)",
server_name,
new_refcount,
)
return client
transition_event = self._transitions.get(server_name)
if transition_event is None:
transition_event = threading.Event()
self._transitions[server_name] = transition_event
self._configs[server_name] = config
should_connect = True
if not should_connect:
transition_event.wait()
continue
client = MCPClient(config)
try:
client.connect()
except Exception:
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._transitions.pop(server_name, None)
if (
server_name not in self._pool
and self._refcounts.get(server_name, 0) <= 0
):
self._configs.pop(server_name, None)
transition_event.set()
raise
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._pool[server_name] = client
self._refcounts[server_name] = self._refcounts.get(server_name, 0) + 1
self._configs[server_name] = config
self._transitions.pop(server_name, None)
transition_event.set()
return client
client.disconnect()
def release(self, server_name: str) -> None:
"""Decrement refcount and disconnect when the last user releases."""
while True:
disconnect_client: MCPClient | None = None
transition_event: threading.Event | None = None
should_disconnect = False
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
refcount = self._refcounts.get(server_name, 0)
if refcount <= 0:
return
if refcount > 1:
self._refcounts[server_name] = refcount - 1
return
disconnect_client = self._pool.pop(server_name, None)
self._refcounts.pop(server_name, None)
transition_event = threading.Event()
self._transitions[server_name] = transition_event
should_disconnect = True
if not should_disconnect:
transition_event.wait()
continue
try:
if disconnect_client is not None:
disconnect_client.disconnect()
finally:
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._transitions.pop(server_name, None)
transition_event.set()
return
def health_check(self, server_name: str) -> bool:
"""Return True when the pooled connection appears healthy."""
while True:
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
client = self._pool.get(server_name)
config = self._configs.get(server_name)
break
transition_event.wait()
if client is None or config is None:
return False
try:
if config.transport == "stdio":
client.list_tools()
return True
if not config.url:
return False
client_kwargs: dict[str, Any] = {
"base_url": config.url,
"headers": config.headers,
"timeout": 5.0,
}
if config.transport == "unix":
if not config.socket_path:
return False
client_kwargs["transport"] = httpx.HTTPTransport(uds=config.socket_path)
with httpx.Client(**client_kwargs) as http_client:
response = http_client.get("/health")
response.raise_for_status()
return True
except Exception:
return False
def reconnect(self, server_name: str) -> MCPClient:
"""Force a disconnect and replace the pooled client with a fresh one."""
while True:
transition_event: threading.Event | None = None
old_client: MCPClient | None = None
with self._pool_lock:
transition_event = self._transitions.get(server_name)
if transition_event is None:
config = self._configs.get(server_name)
if config is None:
raise KeyError(f"Unknown MCP server: {server_name}")
old_client = self._pool.get(server_name)
refcount = self._refcounts.get(server_name, 0)
transition_event = threading.Event()
self._transitions[server_name] = transition_event
break
transition_event.wait()
if old_client is not None:
old_client.disconnect()
new_client = MCPClient(config)
try:
new_client.connect()
except Exception:
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._pool.pop(server_name, None)
self._transitions.pop(server_name, None)
transition_event.set()
raise
with self._pool_lock:
current = self._transitions.get(server_name)
if current is transition_event:
self._pool[server_name] = new_client
self._refcounts[server_name] = max(refcount, 1)
self._transitions.pop(server_name, None)
transition_event.set()
return new_client
new_client.disconnect()
return self.acquire(config)
def cleanup_all(self) -> None:
"""Disconnect all pooled clients and clear manager state."""
while True:
with self._pool_lock:
if self._transitions:
pending = list(self._transitions.values())
else:
cleanup_events = {name: threading.Event() for name in self._pool}
clients = list(self._pool.items())
self._transitions.update(cleanup_events)
self._pool.clear()
self._refcounts.clear()
self._configs.clear()
break
for event in pending:
event.wait()
for _server_name, client in clients:
try:
client.disconnect()
except Exception:
pass
with self._pool_lock:
for server_name, event in cleanup_events.items():
current = self._transitions.get(server_name)
if current is event:
self._transitions.pop(server_name, None)
event.set()
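A minimal usage sketch for the pool (only names defined in this file; cfg is any MCPServerConfig):

manager = MCPConnectionManager.get_instance()
client = manager.acquire(cfg)            # first acquire connects; later ones reuse
try:
    if not manager.health_check(cfg.name):
        client = manager.reconnect(cfg.name)
    tools = client.list_tools()          # use the shared client as normal
finally:
    manager.release(cfg.name)            # disconnects once the refcount hits zero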
+4 -1
@@ -1,6 +1,6 @@
"""Pre-load validation for agent graphs.
Runs structural and credential checks before MCP servers are spawned.
Runs structural, credential, and skill-trust checks before MCP servers are spawned.
Fails fast with actionable error messages.
"""
@@ -169,6 +169,9 @@ def run_preload_validation(
1. Graph structure (includes GCU subagent-only checks): non-recoverable
2. Credentials: potentially recoverable via interactive setup
Skill discovery and trust gating (AS-13) happen later in runner._setup()
so they have access to agent-level skill configuration.
Raises PreloadValidationError for structural issues.
Raises CredentialError for credential issues.
"""
+340 -2
@@ -552,6 +552,319 @@ def get_kimi_code_token() -> str | None:
return None
# ---------------------------------------------------------------------------
# Antigravity subscription token helpers
# ---------------------------------------------------------------------------
# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
# VSCode-style SQLite state database under the key
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
ANTIGRAVITY_IDE_STATE_DB = (
Path.home()
/ "Library"
/ "Application Support"
/ "Antigravity"
/ "User"
/ "globalStorage"
/ "state.vscdb"
)
# Linux fallback for the IDE state DB
ANTIGRAVITY_IDE_STATE_DB_LINUX = (
Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
# Antigravity credentials stored by native OAuth implementation
ANTIGRAVITY_AUTH_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600 # Google access tokens expire in 1 hour
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
def _read_antigravity_ide_credentials() -> dict | None:
"""Read credentials from the Antigravity IDE's SQLite state database.
The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
base64-encoded protobuf blob in a SQLite database. The access token is
a standard Google OAuth ``ya29.*`` bearer token.
Returns:
Dict with ``accessToken`` and optionally ``refreshToken`` keys,
plus ``_source: "ide"`` to skip file-based save on refresh.
Returns None if the database is absent or the key is not found.
"""
import re
import sqlite3
for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
if not db_path.exists():
continue
try:
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
try:
row = con.execute(
"SELECT value FROM ItemTable WHERE key = ?",
(_ANTIGRAVITY_IDE_STATE_DB_KEY,),
).fetchone()
finally:
con.close()
if not row:
continue
import base64
blob = base64.b64decode(row[0])
# The protobuf blob contains the access token (ya29.*) and
# refresh token (1//*) as length-prefixed UTF-8 strings.
# Decode the inner base64 layer and extract with regex.
inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
access_token: str | None = None
refresh_token: str | None = None
for candidate in inner_b64_candidates:
try:
padded = candidate + b"=" * (-len(candidate) % 4)
inner = base64.urlsafe_b64decode(padded)
except Exception:
continue
if not access_token:
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
if m:
access_token = m.group(0).decode("ascii")
if not refresh_token:
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
if m:
refresh_token = m.group(0).decode("ascii")
if access_token and refresh_token:
break
if access_token:
return {
"accounts": [
{
"accessToken": access_token,
"refreshToken": refresh_token or "",
}
],
"_source": "ide",
"_db_path": str(db_path),
}
except Exception as exc:
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
continue
return None
def _read_antigravity_credentials() -> dict | None:
"""Read Antigravity auth data from all supported credential sources.
Checks in order:
1. Antigravity IDE SQLite state database (native macOS/Linux app)
2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)
Returns:
Auth data dict with an ``accounts`` list on success, None otherwise.
"""
# 1. Native Antigravity IDE (primary on macOS)
ide_creds = _read_antigravity_ide_credentials()
if ide_creds:
return ide_creds
# 2. Native OAuth credentials file
if ANTIGRAVITY_AUTH_FILE.exists():
try:
with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
data = json.load(f)
accounts = data.get("accounts", [])
if accounts and isinstance(accounts[0], dict):
return data
except (json.JSONDecodeError, OSError):
pass
return None
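For illustration, the JSON file consumed above has roughly this shape (key names from the code; values are placeholders):

auth_data = {
    "accounts": [
        {"accessToken": "ya29.A0...", "refreshToken": "1//0g..."},
    ],
    "last_refresh": "2026-03-20T21:00:00+00:00",
}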
def _is_antigravity_token_expired(auth_data: dict) -> bool:
"""Check whether the Antigravity access token is expired or near expiry.
For IDE-sourced credentials: uses the state DB's mtime as last_refresh
since the IDE keeps the DB fresh while it's running.
For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
"""
import time
from datetime import datetime
now = time.time()
if auth_data.get("_source") == "ide":
# The IDE refreshes tokens automatically while running.
# Use the DB file's mtime as a proxy for when the token was last updated.
try:
db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
last_refresh: float = db_path.stat().st_mtime
except OSError:
return True
expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
last_refresh_val: float | str | None = auth_data.get("last_refresh")
if last_refresh_val is None:
try:
last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
except OSError:
return True
elif isinstance(last_refresh_val, str):
try:
last_refresh_val = datetime.fromisoformat(
last_refresh_val.replace("Z", "+00:00")
).timestamp()
except (ValueError, TypeError):
return True
expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
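A worked example of the arithmetic above, assuming a 300-second refresh buffer (the real _TOKEN_REFRESH_BUFFER_SECS is defined elsewhere in this file):

last_refresh = 0.0
expires_at = last_refresh + 3600   # _ANTIGRAVITY_TOKEN_LIFETIME_SECS
buffer = 300                       # assumed _TOKEN_REFRESH_BUFFER_SECS
assert 3400 >= expires_at - buffer        # at t=3400s the token counts as expired
assert not (3000 >= expires_at - buffer)  # at t=3000s it is still fresh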
def _refresh_antigravity_token(refresh_token: str) -> dict | None:
"""Refresh the Antigravity access token via Google OAuth.
POSTs form-encoded ``grant_type=refresh_token`` to the Google token
endpoint using Antigravity's public OAuth client ID.
Returns:
Parsed response dict (containing ``access_token``) on success,
None on any error.
"""
import urllib.error
import urllib.parse
import urllib.request
from framework.config import get_antigravity_client_id, get_antigravity_client_secret
client_id = get_antigravity_client_id()
client_secret = get_antigravity_client_secret()
params: dict = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
if client_secret:
params["client_secret"] = client_secret
data = urllib.parse.urlencode(params).encode("utf-8")
req = urllib.request.Request(
ANTIGRAVITY_OAUTH_TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
logger.debug("Antigravity token refresh failed: %s", exc)
return None
def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
"""Write refreshed tokens back to the Antigravity JSON credentials file.
Skipped for IDE-sourced credentials (the IDE manages its own DB).
Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
then persists ``last_refresh`` as an ISO-8601 UTC string.
"""
from datetime import datetime
# IDE manages its own state — we do not write back to its SQLite DB
if auth_data.get("_source") == "ide":
return
try:
accounts = auth_data.get("accounts", [])
if not accounts:
return
account = accounts[0]
account["accessToken"] = token_data["access_token"]
if "refresh_token" in token_data:
account["refreshToken"] = token_data["refresh_token"]
auth_data["accounts"] = accounts
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(auth_data, f, indent=2)
logger.debug("Antigravity credentials refreshed and saved")
except (OSError, KeyError) as exc:
logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)
def get_antigravity_token() -> str | None:
"""Get the OAuth access token from an Antigravity subscription.
Credential sources checked in order:
1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
2. antigravity-auth CLI JSON file
For IDE credentials the token is read directly (the IDE refreshes it
automatically while running). For JSON credentials an automatic OAuth
refresh is attempted when the token is near expiry.
Returns:
The ``ya29.*`` Google OAuth access token, or None if unavailable.
"""
auth_data = _read_antigravity_credentials()
if not auth_data:
return None
accounts = auth_data.get("accounts", [])
if not accounts:
return None
account = accounts[0]
access_token = account.get("accessToken")
if not access_token:
return None
if not _is_antigravity_token_expired(auth_data):
return access_token
# Token is expired or near expiry — attempt a refresh
refresh_token = account.get("refreshToken")
if not refresh_token:
logger.warning(
"Antigravity token expired and no refresh token available. "
"Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
)
return access_token # return stale token; proxy may still accept it briefly
logger.info("Antigravity token expired or near expiry, refreshing...")
token_data = _refresh_antigravity_token(refresh_token)
if token_data and "access_token" in token_data:
_save_refreshed_antigravity_credentials(auth_data, token_data)
return token_data["access_token"]
logger.warning(
"Antigravity token refresh failed. "
"Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
)
return access_token
def _is_antigravity_proxy_available() -> bool:
"""Return True if antigravity-auth serve is running on localhost:8069."""
import socket
try:
with socket.create_connection(("localhost", 8069), timeout=0.5):
return True
except (OSError, TimeoutError):
return False
@dataclass
class AgentInfo:
"""Information about an exported agent."""
@@ -1141,7 +1454,10 @@ class AgentRunner:
# Create LLM provider
# Uses LiteLLM which auto-detects the provider from model name
if self.mock_mode:
# Skip if already injected (e.g. worker agents with a pre-built LLM)
if self._llm is not None:
pass # LLM already configured externally
elif self.mock_mode:
# Use mock LLM for testing without real API calls
from framework.llm.mock import MockLLMProvider
@@ -1155,6 +1471,7 @@ class AgentRunner:
use_claude_code = llm_config.get("use_claude_code_subscription", False)
use_codex = llm_config.get("use_codex_subscription", False)
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
use_antigravity = llm_config.get("use_antigravity_subscription", False)
api_base = llm_config.get("api_base")
api_key = None
@@ -1176,6 +1493,8 @@ class AgentRunner:
if not api_key:
print("Warning: Kimi Code subscription configured but no key found.")
print("Run 'kimi /login' to authenticate, then try again.")
elif use_antigravity:
pass # AntigravityProvider handles credentials internally
if api_key and use_claude_code:
# Use litellm's built-in Anthropic OAuth support.
@@ -1214,6 +1533,19 @@ class AgentRunner:
api_key=api_key,
api_base=api_base,
)
elif use_antigravity:
# Direct OAuth to Google's internal Cloud Code Assist gateway.
# No local proxy required — AntigravityProvider handles token
# refresh and Gemini-format request/response conversion natively.
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
provider = AntigravityProvider(model=self.model)
if not provider.has_credentials():
print(
"Warning: Antigravity credentials not found. "
"Run: uv run python core/antigravity_auth.py auth account add"
)
self._llm = provider
else:
# Local models (e.g. Ollama) don't need an API key
if self._is_local_model(self.model):
@@ -1340,7 +1672,7 @@ class AgentRunner:
except Exception:
pass # Best-effort — agent works without account info
# Skill configuration — the runtime handles discovery, loading, and
# Skill configuration — the runtime handles discovery, loading, trust-gating, and
# prompt rendering. The runner just builds the config.
from framework.skills.config import SkillsConfig
from framework.skills.manager import SkillsManagerConfig
@@ -1351,6 +1683,7 @@ class AgentRunner:
skills=getattr(self, "_agent_skills", None),
),
project_root=self.agent_path,
interactive=self._interactive,
)
self._setup_agent_runtime(
@@ -1381,6 +1714,8 @@ class AgentRunner:
return "MISTRAL_API_KEY"
elif model_lower.startswith("groq/"):
return "GROQ_API_KEY"
elif model_lower.startswith("openrouter/"):
return "OPENROUTER_API_KEY"
elif self._is_local_model(model_lower):
return None # Local models don't need an API key
elif model_lower.startswith("azure/"):
@@ -1460,6 +1795,9 @@ class AgentRunner:
accounts_data: list[dict] | None = None,
tool_provider_map: dict[str, str] | None = None,
event_bus=None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
skills_manager_config=None,
) -> None:
"""Set up multi-entry-point execution using AgentRuntime."""
+43 -11
@@ -54,6 +54,8 @@ class ToolRegistry:
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
self._mcp_clients: list[Any] = [] # List of MCPClient instances
self._mcp_client_servers: dict[int, str] = {} # client id -> server name
self._mcp_managed_clients: set[int] = set() # client ids acquired from the manager
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
# MCP resync tracking
@@ -243,6 +245,13 @@ class ToolRegistry:
def _wrap_result(tool_use_id: str, result: Any) -> ToolResult:
if isinstance(result, ToolResult):
return result
# MCP client returns dict with _images when image content is present
if isinstance(result, dict) and "_images" in result:
return ToolResult(
tool_use_id=tool_use_id,
content=result.get("_text", ""),
image_content=result["_images"],
)
return ToolResult(
tool_use_id=tool_use_id,
content=json.dumps(result) if not isinstance(result, str) else result,
@@ -480,6 +489,7 @@ class ToolRegistry:
def register_mcp_server(
self,
server_config: dict[str, Any],
use_connection_manager: bool = True,
) -> int:
"""
Register an MCP server and discover its tools.
@@ -495,12 +505,14 @@ class ToolRegistry:
- url: Server URL (for http)
- headers: HTTP headers (for http)
- description: Server description (optional)
use_connection_manager: When True, reuse a shared client keyed by server name
Returns:
Number of tools registered from this server
"""
try:
from framework.runner.mcp_client import MCPClient, MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager
# Build config object
config = MCPServerConfig(
@@ -516,11 +528,18 @@ class ToolRegistry:
)
# Create and connect client
client = MCPClient(config)
client.connect()
if use_connection_manager:
client = MCPConnectionManager.get_instance().acquire(config)
else:
client = MCPClient(config)
client.connect()
# Store client for cleanup
self._mcp_clients.append(client)
client_id = id(client)
self._mcp_client_servers[client_id] = config.name
if use_connection_manager:
self._mcp_managed_clients.add(client_id)
# Register each tool
server_name = server_config["name"]
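Illustrative call sites for the new flag (server dict keys mirror MCPServerConfig; the command is a placeholder):

# Default: shared, refcounted client from the connection manager.
registry.register_mcp_server({"name": "files", "transport": "stdio", "command": "mcp-files"})

# Opt out: this registry owns the client's connect/disconnect lifecycle.
registry.register_mcp_server(
    {"name": "files", "transport": "stdio", "command": "mcp-files"},
    use_connection_manager=False,
)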
@@ -560,7 +579,9 @@ class ToolRegistry:
}
merged_inputs = {**clean_inputs, **filtered_context}
result = client_ref.call_tool(tool_name, merged_inputs)
# MCP tools return content array, extract the result
# MCP client already extracts content (returns str
# or {"_text": ..., "_images": ...} for image results).
# Handle legacy list format from HTTP transport.
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "text" in result[0]:
return result[0]["text"]
@@ -720,12 +741,7 @@ class ToolRegistry:
logger.info("%s — resyncing MCP servers", reason)
# 1. Disconnect existing MCP clients
for client in self._mcp_clients:
try:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client during resync: {e}")
self._mcp_clients.clear()
self._cleanup_mcp_clients("during resync")
# 2. Remove MCP-registered tools
for name in self._mcp_tool_names:
@@ -740,12 +756,28 @@ class ToolRegistry:
def cleanup(self) -> None:
"""Clean up all MCP client connections."""
self._cleanup_mcp_clients()
def _cleanup_mcp_clients(self, context: str = "") -> None:
"""Disconnect or release all tracked MCP clients for this registry."""
if context:
context = f" {context}"
for client in self._mcp_clients:
client_id = id(client)
server_name = self._mcp_client_servers.get(client_id, client.config.name)
try:
client.disconnect()
if client_id in self._mcp_managed_clients:
from framework.runner.mcp_connection_manager import MCPConnectionManager
MCPConnectionManager.get_instance().release(server_name)
else:
client.disconnect()
except Exception as e:
logger.warning(f"Error disconnecting MCP client: {e}")
logger.warning(f"Error disconnecting MCP client{context}: {e}")
self._mcp_clients.clear()
self._mcp_client_servers.clear()
self._mcp_managed_clients.clear()
def __del__(self):
"""Destructor to ensure cleanup."""
+22 -2
@@ -137,6 +137,7 @@ class AgentRuntime:
# Deprecated — pass skills_manager_config instead.
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize agent runtime.
@@ -158,6 +159,9 @@ class AgentRuntime:
event_bus: Optional external EventBus. If provided, the runtime shares
this bus instead of creating its own. Used by SessionManager to
share a single bus between queen, worker, and judge.
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
skills_manager_config: Skill configuration; the runtime owns
discovery, loading, and prompt rendering internally.
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -195,6 +199,8 @@ class AgentRuntime:
self._skills_manager = SkillsManager()
self._skills_manager.load()
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
# Primary graph identity
self._graph_id: str = graph_id or "primary"
@@ -341,6 +347,7 @@ class AgentRuntime:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
await stream.start()
self._streams[ep_id] = stream
@@ -977,6 +984,7 @@ class AgentRuntime:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
)
if self._running:
await stream.start()
@@ -1466,6 +1474,7 @@ class AgentRuntime:
graph_id: str | None = None,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> bool:
"""Inject user input into a running client-facing node.
@@ -1478,6 +1487,8 @@ class AgentRuntime:
graph_id: Optional graph to search first (defaults to active graph)
is_client_input: True when the message originates from a real
human user (e.g. /chat endpoint), False for external events.
image_content: Optional list of image content blocks (OpenAI
image_url format) to include alongside the text.
Returns:
True if input was delivered, False if no matching node found
@@ -1489,7 +1500,9 @@ class AgentRuntime:
target = graph_id or self._active_graph_id
if target in self._graphs:
for stream in self._graphs[target].streams.values():
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
if await stream.inject_input(
node_id, content, is_client_input=is_client_input, image_content=image_content
):
return True
# Then search all other graphs
@@ -1497,7 +1510,9 @@ class AgentRuntime:
if gid == target:
continue
for stream in reg.streams.values():
if await stream.inject_input(node_id, content, is_client_input=is_client_input):
if await stream.inject_input(
node_id, content, is_client_input=is_client_input, image_content=image_content
):
return True
return False
@@ -1760,6 +1775,7 @@ def create_agent_runtime(
# Deprecated — pass skills_manager_config instead.
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
) -> AgentRuntime:
"""
Create and configure an AgentRuntime with entry points.
@@ -1786,6 +1802,9 @@ def create_agent_runtime(
accounts_data: Raw account data for per-node prompt generation.
tool_provider_map: Tool name to provider name mapping for account routing.
event_bus: Optional external EventBus to share with other components.
skills_catalog_prompt: Available skills catalog for system prompt.
protocols_prompt: Default skill operational protocols for system prompt.
skill_dirs: Skill base directories for Tier 3 resource access.
skills_manager_config: Skill configuration; the runtime owns
discovery, loading, and prompt rendering internally.
skills_catalog_prompt: Deprecated. Pre-rendered skills catalog.
@@ -1819,6 +1838,7 @@ def create_agent_runtime(
skills_manager_config=skills_manager_config,
skills_catalog_prompt=skills_catalog_prompt,
protocols_prompt=protocols_prompt,
skill_dirs=skill_dirs,
)
for spec in entry_points:
+2
@@ -117,6 +117,7 @@ class EventType(StrEnum):
# Context management
CONTEXT_COMPACTED = "context_compacted"
CONTEXT_USAGE_UPDATED = "context_usage_updated"
# External triggers
WEBHOOK_RECEIVED = "webhook_received"
@@ -159,6 +160,7 @@ class EventType(StrEnum):
TRIGGER_DEACTIVATED = "trigger_deactivated"
TRIGGER_FIRED = "trigger_fired"
TRIGGER_REMOVED = "trigger_removed"
TRIGGER_UPDATED = "trigger_updated"
@dataclass
+8 -1
@@ -188,6 +188,7 @@ class ExecutionStream:
tool_provider_map: dict[str, str] | None = None,
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
):
"""
Initialize execution stream.
@@ -213,6 +214,7 @@ class ExecutionStream:
tool_provider_map: Tool name to provider name mapping for account routing
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
"""
self.stream_id = stream_id
self.entry_spec = entry_spec
@@ -236,6 +238,7 @@ class ExecutionStream:
self._tool_provider_map = tool_provider_map
self._skills_catalog_prompt = skills_catalog_prompt
self._protocols_prompt = protocols_prompt
self._skill_dirs: list[str] = skill_dirs or []
_es_logger = logging.getLogger(__name__)
if protocols_prompt:
@@ -430,6 +433,7 @@ class ExecutionStream:
content: str,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> bool:
"""Inject user input into a running client-facing EventLoopNode.
@@ -441,7 +445,9 @@ class ExecutionStream:
for executor in self._active_executors.values():
node = executor.node_registry.get(node_id)
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(content, is_client_input=is_client_input)
await node.inject_event(
content, is_client_input=is_client_input, image_content=image_content
)
return True
return False
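An illustrative call threading image blocks through the new parameter (node ID and data URI are placeholders):

delivered = await stream.inject_input(
    "chat",
    "Describe the attached screenshot",
    is_client_input=True,
    image_content=[
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
)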
@@ -696,6 +702,7 @@ class ExecutionStream:
tool_provider_map=self._tool_provider_map,
skills_catalog_prompt=self._skills_catalog_prompt,
protocols_prompt=self._protocols_prompt,
skill_dirs=self._skill_dirs,
)
# Track executor so inject_input() can reach EventLoopNode instances
self._active_executors[execution_id] = executor
@@ -8,6 +8,7 @@ write. Errors are silently swallowed — this must never break the agent.
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import IO, Any
@@ -47,6 +48,9 @@ def log_llm_turn(
Never raises.
"""
try:
# Skip logging during test runs to avoid polluting real logs.
if os.environ.get("PYTEST_CURRENT_TEST") or os.environ.get("HIVE_DISABLE_LLM_LOGS"):
return
global _log_file, _log_ready # noqa: PLW0603
if not _log_ready:
_log_file = _open_log()
@@ -69,6 +69,7 @@ async def create_queen(
QueenPhaseState,
register_queen_lifecycle_tools,
)
from framework.tools.queen_memory_tools import register_queen_memory_tools
hive_home = Path.home() / ".hive"
@@ -122,6 +123,9 @@ async def create_queen(
phase_state=phase_state,
)
# ---- Episodic memory tools (always registered) ---------------------
register_queen_memory_tools(queen_registry)
# ---- Monitoring tools (only when worker is loaded) ----------------
if session.worker_runtime:
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
+2
@@ -37,6 +37,7 @@ DEFAULT_EVENT_TYPES = [
EventType.NODE_RETRY,
EventType.NODE_TOOL_DOOM_LOOP,
EventType.CONTEXT_COMPACTED,
EventType.CONTEXT_USAGE_UPDATED,
EventType.WORKER_LOADED,
EventType.CREDENTIALS_REQUIRED,
EventType.SUBAGENT_REPORT,
@@ -46,6 +47,7 @@ DEFAULT_EVENT_TYPES = [
EventType.TRIGGER_DEACTIVATED,
EventType.TRIGGER_FIRED,
EventType.TRIGGER_REMOVED,
EventType.TRIGGER_UPDATED,
EventType.DRAFT_GRAPH_UPDATED,
]
+11 -4
@@ -108,7 +108,10 @@ async def handle_chat(request: web.Request) -> web.Response:
The input box is permanently connected to the queen agent.
Worker input is handled separately via /worker-input.
Body: {"message": "hello"}
Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}
The optional ``images`` field accepts a list of OpenAI-format image_url
content blocks. The frontend encodes images as base64 data URIs.
"""
session, err = resolve_session(request)
if err:
@@ -116,15 +119,16 @@ async def handle_chat(request: web.Request) -> web.Response:
body = await request.json()
message = body.get("message", "")
image_content = body.get("images") or None # list[dict] | None
if not message:
if not message and not image_content:
return web.json_response({"error": "message is required"}, status=400)
queen_executor = session.queen_executor
if queen_executor is not None:
node = queen_executor.node_registry.get("queen")
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(message, is_client_input=True)
await node.inject_event(message, is_client_input=True, image_content=image_content)
# Publish to EventBus so the session event log captures user messages
from framework.runtime.event_bus import AgentEvent, EventType
@@ -134,7 +138,10 @@ async def handle_chat(request: web.Request) -> web.Response:
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"content": message},
data={
"content": message,
"image_count": len(image_content) if image_content else 0,
},
)
)
return web.json_response(
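For reference, a /chat request body carrying an image, matching the docstring above (data URI truncated):

body = {
    "message": "What does this error dialog say?",
    "images": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}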
+146 -63
@@ -11,7 +11,6 @@ Session-primary routes:
- GET /api/sessions/{session_id}/entry-points list entry points
- PATCH /api/sessions/{session_id}/triggers/{id} update trigger task
- GET /api/sessions/{session_id}/graphs list graph IDs
- GET /api/sessions/{session_id}/queen-messages queen conversation history
- GET /api/sessions/{session_id}/events/history persisted eventbus log (for replay)
Worker session browsing (persisted execution runs on disk):
@@ -24,9 +23,13 @@ Worker session browsing (persisted execution runs on disk):
"""
import asyncio
import contextlib
import json
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path
@@ -50,8 +53,11 @@ def _get_manager(request: web.Request) -> SessionManager:
def _session_to_live_dict(session) -> dict:
"""Serialize a live Session to the session-primary JSON shape."""
from framework.llm.capabilities import supports_image_tool_results
info = session.worker_info
phase_state = getattr(session, "phase_state", None)
queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
return {
"session_id": session.id,
"worker_id": session.worker_id,
@@ -67,6 +73,7 @@ def _session_to_live_dict(session) -> dict:
"queen_phase": phase_state.phase
if phase_state
else ("staging" if session.worker_runtime else "planning"),
"queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
}
@@ -408,7 +415,7 @@ async def handle_session_entry_points(request: web.Request) -> web.Response:
async def handle_update_trigger_task(request: web.Request) -> web.Response:
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger task."""
"""PATCH /api/sessions/{session_id}/triggers/{trigger_id} — update trigger fields."""
session, err = resolve_session(request)
if err:
return err
@@ -427,30 +434,136 @@ async def handle_update_trigger_task(request: web.Request) -> web.Response:
except Exception:
return web.json_response({"error": "Invalid JSON body"}, status=400)
task = body.get("task")
if task is None:
return web.json_response({"error": "Missing 'task' field"}, status=400)
if not isinstance(task, str):
return web.json_response({"error": "'task' must be a string"}, status=400)
updates: dict[str, object] = {}
tdef.task = task
if "task" in body:
task = body.get("task")
if not isinstance(task, str):
return web.json_response({"error": "'task' must be a string"}, status=400)
tdef.task = task
updates["task"] = tdef.task
trigger_config_update = body.get("trigger_config")
if trigger_config_update is not None:
if not isinstance(trigger_config_update, dict):
return web.json_response(
{"error": "'trigger_config' must be an object"},
status=400,
)
merged_trigger_config = dict(tdef.trigger_config)
merged_trigger_config.update(trigger_config_update)
if tdef.trigger_type == "timer":
cron_expr = merged_trigger_config.get("cron")
interval = merged_trigger_config.get("interval_minutes")
if cron_expr is not None and not isinstance(cron_expr, str):
return web.json_response(
{"error": "'trigger_config.cron' must be a string"},
status=400,
)
if cron_expr:
try:
from croniter import croniter
if not croniter.is_valid(cron_expr):
return web.json_response(
{"error": f"Invalid cron expression: {cron_expr}"},
status=400,
)
except ImportError:
return web.json_response(
{
"error": (
"croniter package not installed — cannot validate cron expression."
)
},
status=500,
)
merged_trigger_config.pop("interval_minutes", None)
elif interval is None:
return web.json_response(
{
"error": (
"Timer trigger needs 'cron' or 'interval_minutes' in trigger_config."
)
},
status=400,
)
elif not isinstance(interval, (int, float)) or interval <= 0:
return web.json_response(
{"error": "'trigger_config.interval_minutes' must be > 0"},
status=400,
)
tdef.trigger_config = merged_trigger_config
updates["trigger_config"] = tdef.trigger_config
if not updates:
return web.json_response(
{"error": "Provide at least one of 'task' or 'trigger_config'"},
status=400,
)
# Persist to session state and agent definition
from framework.tools.queen_lifecycle_tools import (
_persist_active_triggers,
_save_trigger_to_agent,
_start_trigger_timer,
_start_trigger_webhook,
)
if "trigger_config" in updates and trigger_id in getattr(session, "active_trigger_ids", set()):
task = session.active_timer_tasks.pop(trigger_id, None)
if task and not task.done():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
getattr(session, "trigger_next_fire", {}).pop(trigger_id, None)
webhook_subs = getattr(session, "active_webhook_subs", {})
if sub_id := webhook_subs.pop(trigger_id, None):
with contextlib.suppress(Exception):
session.event_bus.unsubscribe(sub_id)
if tdef.trigger_type == "timer":
await _start_trigger_timer(session, trigger_id, tdef)
elif tdef.trigger_type == "webhook":
await _start_trigger_webhook(session, trigger_id, tdef)
if trigger_id in getattr(session, "active_trigger_ids", set()):
session_id = request.match_info["session_id"]
await _persist_active_triggers(session, session_id)
_save_trigger_to_agent(session, trigger_id, tdef)
# Emit SSE event so the frontend updates the graph and detail panel
bus = getattr(session, "event_bus", None)
if bus:
from framework.runtime.event_bus import AgentEvent, EventType
await bus.publish(
AgentEvent(
type=EventType.TRIGGER_UPDATED,
stream_id="queen",
data={
"trigger_id": trigger_id,
"task": tdef.task,
"trigger_config": tdef.trigger_config,
"trigger_type": tdef.trigger_type,
"name": tdef.description or trigger_id,
"entry_node": getattr(
getattr(getattr(session, "runner", None), "graph", None),
"entry_node",
None,
),
},
)
)
return web.json_response(
{
"trigger_id": trigger_id,
"task": tdef.task,
"trigger_config": tdef.trigger_config,
}
)
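Illustrative PATCH bodies the extended handler now accepts (the cron string is an example; croniter validates it):

patch = {"trigger_config": {"cron": "*/15 * * * *"}}  # reschedule: every 15 minutes
patch = {"task": "Summarize new support tickets"}     # update only the task text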
@@ -754,60 +867,6 @@ async def handle_messages(request: web.Request) -> web.Response:
return web.json_response({"messages": all_messages})
async def handle_queen_messages(request: web.Request) -> web.Response:
"""GET /api/sessions/{session_id}/queen-messages — get queen conversation.
Reads directly from disk so it works for both live sessions and cold
(post-server-restart) sessions; no live session required.
"""
session_id = request.match_info["session_id"]
queen_dir = Path.home() / ".hive" / "queen" / "session" / session_id
convs_dir = queen_dir / "conversations"
if not convs_dir.exists():
return web.json_response({"messages": [], "session_id": session_id})
all_messages: list[dict] = []
def _read_parts(parts_dir: Path, node_id: str) -> None:
if not parts_dir.exists():
return
for part_file in sorted(parts_dir.iterdir()):
if part_file.suffix != ".json":
continue
try:
part = json.loads(part_file.read_text(encoding="utf-8"))
part["_node_id"] = node_id
# Use file mtime as created_at so frontend can order
# queen and worker messages chronologically.
part.setdefault("created_at", part_file.stat().st_mtime)
all_messages.append(part)
except (json.JSONDecodeError, OSError):
continue
# Flat layout: conversations/parts/*.json
_read_parts(convs_dir / "parts", "queen")
# Node-based layout: conversations/<node_id>/parts/*.json
for node_dir in convs_dir.iterdir():
if not node_dir.is_dir() or node_dir.name == "parts":
continue
_read_parts(node_dir / "parts", node_dir.name)
all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
# Filter to client-facing messages only
all_messages = [
m
for m in all_messages
if not m.get("is_transition_marker")
and m["role"] != "tool"
and not (m["role"] == "assistant" and m.get("tool_calls"))
]
return web.json_response({"messages": all_messages, "session_id": session_id})
async def handle_session_events_history(request: web.Request) -> web.Response:
"""GET /api/sessions/{session_id}/events/history — persisted eventbus log.
@@ -925,6 +984,29 @@ async def handle_discover(request: web.Request) -> web.Response:
return web.json_response(result)
async def handle_reveal_session_folder(request: web.Request) -> web.Response:
"""POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
manager: SessionManager = request.app["manager"]
session_id = request.match_info["session_id"]
session = manager.get_session(session_id)
storage_session_id = (session.queen_resume_from or session.id) if session else session_id
folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
folder.mkdir(parents=True, exist_ok=True)
try:
if sys.platform == "darwin":
subprocess.Popen(["open", str(folder)])
elif sys.platform == "win32":
subprocess.Popen(["explorer", str(folder)])
else:
subprocess.Popen(["xdg-open", str(folder)])
except Exception as exc:
return web.json_response({"error": str(exc)}, status=500)
return web.json_response({"path": str(folder)})
# ------------------------------------------------------------------
# Route registration
# ------------------------------------------------------------------
@@ -949,13 +1031,14 @@ def register_routes(app: web.Application) -> None:
app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)
# Session info
app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
app.router.add_patch(
"/api/sessions/{session_id}/triggers/{trigger_id}", handle_update_trigger_task
)
app.router.add_get("/api/sessions/{session_id}/graphs", handle_session_graphs)
app.router.add_get("/api/sessions/{session_id}/queen-messages", handle_queen_messages)
app.router.add_get("/api/sessions/{session_id}/events/history", handle_session_events_history)
# Worker session browsing (session-primary)
+373 -66
@@ -47,6 +47,8 @@ class Session:
worker_handoff_sub: str | None = None
# Memory consolidation subscription (fires on CONTEXT_COMPACTED)
memory_consolidation_sub: str | None = None
# Worker run digest subscription (fires on EXECUTION_COMPLETED / EXECUTION_FAILED)
worker_digest_sub: str | None = None
# Trigger definitions loaded from agent's triggers.json (available but inactive)
available_triggers: dict[str, TriggerDefinition] = field(default_factory=dict)
# Active trigger tracking (IDs currently firing + their asyncio tasks)
@@ -94,8 +96,7 @@ class SessionManager:
Internal helper use create_session() or create_session_with_worker().
"""
from framework.config import RuntimeConfig
from framework.llm.litellm import LiteLLMProvider
from framework.config import RuntimeConfig, get_hive_config
from framework.runtime.event_bus import EventBus
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -109,12 +110,20 @@ class SessionManager:
rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)
# Session owns these — shared with queen and worker
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
llm_config = get_hive_config().get("llm", {})
if llm_config.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
llm = AntigravityProvider(model=rc.model)
else:
from framework.llm.litellm import LiteLLMProvider
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
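The branch above keys off the session-level LLM config; a config returning this from get_hive_config() routes the session through AntigravityProvider (dict shape from the code; the on-disk format is not shown in this excerpt):

config = {"llm": {"use_antigravity_subscription": True}}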
event_bus = EventBus()
session = Session(
@@ -177,6 +186,31 @@ class SessionManager:
agent_path = Path(agent_path)
resolved_worker_id = agent_id or agent_path.name
# When cold-restoring, check meta.json for the phase — if the agent
# was still being built we must NOT try to load the worker (the code
# is incomplete and will fail to import).
if queen_resume_from:
_resume_phase = None
_meta_path = (
Path.home() / ".hive" / "queen" / "session" / queen_resume_from / "meta.json"
)
if _meta_path.exists():
try:
_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
_resume_phase = _meta.get("phase")
except (json.JSONDecodeError, OSError):
pass
if _resume_phase in ("building", "planning"):
# Fall back to queen-only session — cold resume handler in
# _start_queen will set phase_state.agent_path and switch to
# the correct phase.
return await self.create_session(
session_id=session_id,
model=model,
initial_prompt=initial_prompt,
queen_resume_from=queen_resume_from,
)
# Reuse the original session ID when cold-restoring so the frontend
# sees one continuous session instead of a new one each time.
session = await self._create_session_core(
@@ -193,6 +227,9 @@ class SessionManager:
model=model,
)
# Restore active triggers from persisted state (cold restore)
await self._restore_active_triggers(session, session.id)
# Start queen with worker profile + lifecycle + monitoring tools
worker_identity = (
build_worker_profile(session.worker_runtime, agent_path=agent_path)
@@ -204,7 +241,23 @@ class SessionManager:
)
except Exception:
# If anything fails, tear down the session
if queen_resume_from:
# Cold restore: worker load failed (e.g. incomplete code from a
# building session). Fall back to queen-only so the user can
# continue the conversation and fix / rebuild the agent.
logger.warning(
"Cold restore: worker load failed for '%s', falling back to queen-only",
agent_path,
exc_info=True,
)
await self.stop_session(session.id)
return await self.create_session(
session_id=session_id,
model=model,
initial_prompt=initial_prompt,
queen_resume_from=queen_resume_from,
)
# If anything fails (non-cold-restore), tear down the session
await self.stop_session(session.id)
raise
return session
@@ -241,7 +294,17 @@ class SessionManager:
try:
# Blocking I/O — load in executor
loop = asyncio.get_running_loop()
resolved_model = model or self._model
# Prioritize: explicit model arg > worker-specific model > session default
from framework.config import (
get_preferred_worker_model,
get_worker_api_base,
get_worker_api_key,
get_worker_llm_extra_kwargs,
)
worker_model = get_preferred_worker_model()
resolved_model = model or worker_model or self._model
runner = await loop.run_in_executor(
None,
lambda: AgentRunner.load(
@@ -253,6 +316,30 @@ class SessionManager:
),
)
# If a worker-specific model is configured, build an LLM provider
# with the correct worker credentials so _setup() doesn't fall back
# to the queen's llm config (which may be a different provider).
if worker_model and not model:
from framework.config import get_hive_config
worker_llm_cfg = get_hive_config().get("worker_llm", {})
if worker_llm_cfg.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
runner._llm = AntigravityProvider(model=resolved_model)
else:
from framework.llm.litellm import LiteLLMProvider
worker_api_key = get_worker_api_key()
worker_api_base = get_worker_api_base()
worker_extra = get_worker_llm_extra_kwargs()
runner._llm = LiteLLMProvider(
model=resolved_model,
api_key=worker_api_key,
api_base=worker_api_base,
**worker_extra,
)
# Setup with session's event bus
if runner._agent_runtime is None:
await loop.run_in_executor(
@@ -297,6 +384,9 @@ class SessionManager:
session.worker_runtime = runtime
session.worker_info = info
# Subscribe to execution completion for per-run digest generation
self._subscribe_worker_digest(session)
async with self._lock:
self._loading.discard(session.id)
@@ -399,6 +489,51 @@ class SessionManager:
return False
return True
async def _restore_active_triggers(self, session: "Session", session_id: str) -> None:
"""Restore previously active triggers from persisted session state.
Called after worker loading to restart any timer/webhook triggers
that were active before a server restart.
"""
if not session.available_triggers or not session.worker_runtime:
return
try:
store = session.worker_runtime._session_store
state = await store.read_state(session_id)
if state and state.active_triggers:
from framework.tools.queen_lifecycle_tools import (
_start_trigger_timer,
_start_trigger_webhook,
)
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
for tid in state.active_triggers:
tdef = session.available_triggers.get(tid)
if tdef:
# Restore user-configured task override
saved_task = saved_tasks.get(tid, "")
if saved_task:
tdef.task = saved_task
tdef.active = True
session.active_trigger_ids.add(tid)
if tdef.trigger_type == "timer":
await _start_trigger_timer(session, tid, tdef)
logger.info("Restored trigger timer '%s'", tid)
elif tdef.trigger_type == "webhook":
await _start_trigger_webhook(session, tid, tdef)
logger.info("Restored webhook trigger '%s'", tid)
else:
logger.warning(
"Saved trigger '%s' not found in worker entry points, skipping",
tid,
)
# Restore worker_configured flag
if state and getattr(state, "worker_configured", False):
session.worker_configured = True
except Exception as e:
logger.warning("Failed to restore active triggers: %s", e)
async def load_worker(
self,
session_id: str,
@@ -447,44 +582,7 @@ class SessionManager:
except OSError:
pass
# Restore previously active triggers from persisted session state
if session.available_triggers and session.worker_runtime:
try:
store = session.worker_runtime._session_store
state = await store.read_state(session_id)
if state and state.active_triggers:
from framework.tools.queen_lifecycle_tools import (
_start_trigger_timer,
_start_trigger_webhook,
)
saved_tasks = getattr(state, "trigger_tasks", {}) or {}
for tid in state.active_triggers:
tdef = session.available_triggers.get(tid)
if tdef:
# Restore user-configured task override
saved_task = saved_tasks.get(tid, "")
if saved_task:
tdef.task = saved_task
tdef.active = True
session.active_trigger_ids.add(tid)
if tdef.trigger_type == "timer":
await _start_trigger_timer(session, tid, tdef)
logger.info("Restored trigger timer '%s'", tid)
elif tdef.trigger_type == "webhook":
await _start_trigger_webhook(session, tid, tdef)
logger.info("Restored webhook trigger '%s'", tid)
else:
logger.warning(
"Saved trigger '%s' not found in worker entry points, skipping",
tid,
)
# Restore worker_configured flag
if state and getattr(state, "worker_configured", False):
session.worker_configured = True
except Exception as e:
logger.warning("Failed to restore active triggers: %s", e)
await self._restore_active_triggers(session, session_id)
# Emit SSE event so the frontend can update UI
await self._emit_worker_loaded(session)
@@ -526,6 +624,13 @@ class SessionManager:
await self._emit_trigger_events(session, "removed", session.available_triggers)
session.available_triggers.clear()
if session.worker_digest_sub is not None:
try:
session.event_bus.unsubscribe(session.worker_digest_sub)
except Exception:
pass
session.worker_digest_sub = None
worker_id = session.worker_id
session.worker_id = None
session.worker_path = None
@@ -563,6 +668,13 @@ class SessionManager:
pass
session.worker_handoff_sub = None
if session.worker_digest_sub is not None:
try:
session.event_bus.unsubscribe(session.worker_digest_sub)
except Exception:
pass
session.worker_digest_sub = None
# Stop queen and memory consolidation subscription
if session.memory_consolidation_sub is not None:
try:
@@ -647,6 +759,135 @@ class SessionManager:
else:
logger.warning("Worker handoff received but queen node not ready")
def _subscribe_worker_digest(self, session: Session) -> None:
"""Subscribe to worker events to write per-run digests.
Three triggers:
- NODE_LOOP_ITERATION: write a mid-run snapshot, throttled to at most
once every _DIGEST_COOLDOWN seconds per execution.
- TOOL_CALL_COMPLETED for delegate_to_sub_agent: same throttled snapshot.
Orchestrator nodes often run all subagent calls in a single LLM turn,
so NODE_LOOP_ITERATION only fires once at the end. Subagent
completions provide intermediate checkpoints.
- EXECUTION_COMPLETED / EXECUTION_FAILED: always write the final digest,
bypassing the cooldown.
"""
import time as _time
from framework.runtime.event_bus import EventType as _ET
_DIGEST_COOLDOWN = 300.0 # seconds between mid-run snapshots
if session.worker_digest_sub is not None:
try:
session.event_bus.unsubscribe(session.worker_digest_sub)
except Exception:
pass
session.worker_digest_sub = None
agent_name = session.worker_path.name if session.worker_path else None
if not agent_name:
return
_agent_name = agent_name
_llm = session.llm
_bus = session.event_bus
# per-execution_id monotonic timestamp of last mid-run digest
_last_digest: dict[str, float] = {}
def _resolve_run_id(exec_id: str) -> str | None:
"""Look up the run_id for a given execution_id via EXECUTION_STARTED history."""
for e in _bus.get_history(event_type=_ET.EXECUTION_STARTED, limit=200):
if e.execution_id == exec_id and getattr(e, "run_id", None):
return e.run_id
return None
async def _inject_digest_to_queen(run_id: str) -> None:
"""Read the written digest and push it into the queen's conversation."""
from framework.agents.worker_memory import digest_path
try:
content = digest_path(_agent_name, run_id).read_text(encoding="utf-8").strip()
except OSError:
return
if not content:
return
executor = session.queen_executor
if executor is None:
return
node = executor.node_registry.get("queen")
if node is None or not hasattr(node, "inject_event"):
return
await node.inject_event(f"[WORKER_DIGEST]\n{content}")
async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
"""Write the digest then push it to the queen."""
from framework.agents.worker_memory import consolidate_worker_run
await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
await _inject_digest_to_queen(run_id)
async def _on_worker_event(event: Any) -> None:
if event.stream_id == "queen":
return
exec_id = event.execution_id
if event.type == _ET.EXECUTION_STARTED:
# New run on this execution_id — start the cooldown timer so
# mid-run snapshots don't fire immediately at session start.
# The first snapshot will happen after _DIGEST_COOLDOWN seconds.
if exec_id:
_last_digest[exec_id] = _time.monotonic()
elif event.type in (
_ET.EXECUTION_COMPLETED,
_ET.EXECUTION_FAILED,
_ET.EXECUTION_PAUSED,
):
# Final digest — always fire, ignore cooldown.
# EXECUTION_PAUSED covers cancellation (queen re-triggering the
# worker cancels the previous execution, emitting paused).
run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
if run_id:
asyncio.create_task(
_consolidate_and_notify(run_id, event),
name=f"worker-digest-final-{run_id}",
)
elif event.type in (_ET.NODE_LOOP_ITERATION, _ET.TOOL_CALL_COMPLETED):
# Mid-run snapshot — respect 300 s cooldown per execution.
# TOOL_CALL_COMPLETED is only interesting for subagent calls;
# regular tool completions are too frequent and too cheap.
if event.type == _ET.TOOL_CALL_COMPLETED:
tool_name = (event.data or {}).get("tool_name", "")
if tool_name != "delegate_to_sub_agent":
return
if not exec_id:
return
now = _time.monotonic()
if now - _last_digest.get(exec_id, 0.0) < _DIGEST_COOLDOWN:
return
run_id = _resolve_run_id(exec_id)
if run_id:
_last_digest[exec_id] = now
asyncio.create_task(
_consolidate_and_notify(run_id, None),
name=f"worker-digest-{run_id}",
)
session.worker_digest_sub = session.event_bus.subscribe(
event_types=[
_ET.EXECUTION_STARTED,
_ET.NODE_LOOP_ITERATION,
_ET.TOOL_CALL_COMPLETED,
_ET.EXECUTION_COMPLETED,
_ET.EXECUTION_FAILED,
_ET.EXECUTION_PAUSED,
],
handler=_on_worker_event,
)
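For reference, a minimal standalone sketch of the throttling scheme the docstring above describes; names are simplified, and only the 300 s cooldown value is taken from the code:
import time

_DIGEST_COOLDOWN = 300.0  # seconds between mid-run snapshots
_last_digest: dict[str, float] = {}  # per-execution_id timestamp of last snapshot

def should_write_digest(exec_id: str, final: bool) -> bool:
    """Final digests always fire; mid-run snapshots respect the per-execution cooldown."""
    if final:
        return True  # EXECUTION_COMPLETED / FAILED / PAUSED bypass the cooldown
    now = time.monotonic()
    if now - _last_digest.get(exec_id, 0.0) < _DIGEST_COOLDOWN:
        return False  # too soon since the last mid-run snapshot
    _last_digest[exec_id] = now
    return True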
def _subscribe_worker_handoffs(self, session: Session, executor: Any) -> None:
"""Subscribe queen to worker/subagent escalation handoff events."""
from framework.runtime.event_bus import EventType as _ET
@@ -700,16 +941,21 @@ class SessionManager:
else None
)
)
_meta_path.write_text(
json.dumps(
{
"agent_name": _agent_name,
"agent_path": str(session.worker_path) if session.worker_path else None,
"created_at": time.time(),
}
),
encoding="utf-8",
)
# Merge into existing meta.json to preserve fields written by
# _update_meta_json (e.g. phase, agent_path set during building).
_existing_meta: dict = {}
if _meta_path.exists():
try:
_existing_meta = json.loads(_meta_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
pass
_new_meta: dict = {"created_at": time.time()}
if _agent_name is not None:
_new_meta["agent_name"] = _agent_name
if session.worker_path is not None:
_new_meta["agent_path"] = str(session.worker_path)
_existing_meta.update(_new_meta)
_meta_path.write_text(json.dumps(_existing_meta), encoding="utf-8")
except OSError:
pass
@@ -719,6 +965,7 @@ class SessionManager:
# then use max+1 as offset so resumed sessions produce monotonically
# increasing iteration values — preventing frontend message ID collisions.
iteration_offset = 0
last_phase = ""
events_path = queen_dir / "events.jsonl"
try:
if events_path.exists():
@@ -730,17 +977,25 @@ class SessionManager:
continue
try:
evt = json.loads(line)
it = evt.get("data", {}).get("iteration")
data = evt.get("data", {})
it = data.get("iteration")
if isinstance(it, int) and it > max_iter:
max_iter = it
# Track the latest queen phase from QUEEN_PHASE_CHANGED events
if evt.get("type") == "queen_phase_changed":
phase = data.get("phase")
if phase:
last_phase = phase
except (json.JSONDecodeError, TypeError):
continue
if max_iter >= 0:
iteration_offset = max_iter + 1
logger.info(
"Session '%s' resuming with iteration_offset=%d (from events.jsonl max)",
"Session '%s' resuming with iteration_offset=%d"
" (from events.jsonl max), last phase: %s",
session.id,
iteration_offset,
last_phase or "unknown",
)
except OSError:
pass
@@ -762,11 +1017,27 @@ class SessionManager:
try:
_meta = json.loads(meta_path.read_text(encoding="utf-8"))
_agent_path = _meta.get("agent_path")
_phase = _meta.get("phase")
if _agent_path and Path(_agent_path).exists():
await self.load_worker(session.id, _agent_path)
if session.phase_state:
await session.phase_state.switch_to_staging(source="auto")
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
if _phase in ("staging", "running", None):
# Agent fully built — load worker and resume
await self.load_worker(session.id, _agent_path)
if session.phase_state:
await session.phase_state.switch_to_staging(source="auto")
# Emit flowchart overlay so frontend can display it
await self._emit_flowchart_on_restore(session, _agent_path)
logger.info("Cold restore: auto-loaded worker from %s", _agent_path)
elif _phase == "building":
# Agent folder exists but incomplete — resume building
if session.phase_state:
session.phase_state.agent_path = _agent_path
await session.phase_state.switch_to_building(source="auto")
logger.info("Cold restore: resumed BUILDING phase for %s", _agent_path)
elif _phase == "planning":
if session.phase_state:
session.phase_state.agent_path = _agent_path
logger.info("Cold restore: PLANNING phase for %s", _agent_path)
except Exception:
logger.warning("Cold restore: failed to auto-load worker", exc_info=True)
@@ -776,10 +1047,17 @@ class SessionManager:
_consolidation_session_dir = queen_dir
async def _on_compaction(_event) -> None:
# Only consolidate on queen compactions — worker and subagent
# compactions are frequent and don't warrant a memory update.
if getattr(_event, "stream_id", None) != "queen":
return
from framework.agents.queen.queen_memory import consolidate_queen_memory
await consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
asyncio.create_task(
consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
),
name=f"queen-memory-consolidation-{session.id}",
)
from framework.runtime.event_bus import EventType as _ET
@@ -841,6 +1119,29 @@ class SessionManager:
)
)
async def _emit_flowchart_on_restore(self, session: Session, agent_path: str | Path) -> None:
"""Emit FLOWCHART_MAP_UPDATED from persisted flowchart file on cold restore."""
from framework.runtime.event_bus import AgentEvent, EventType
from framework.tools.flowchart_utils import load_flowchart_file
original_draft, flowchart_map = load_flowchart_file(agent_path)
if original_draft is None:
return
# Cache in phase_state so the REST endpoint also returns it
if session.phase_state:
session.phase_state.original_draft_graph = original_draft
session.phase_state.flowchart_map = flowchart_map
await session.event_bus.publish(
AgentEvent(
type=EventType.FLOWCHART_MAP_UPDATED,
stream_id="queen",
data={
"map": flowchart_map,
"original_draft": original_draft,
},
)
)
async def _notify_queen_worker_unloaded(self, session: Session) -> None:
"""Notify the queen that the worker has been unloaded."""
executor = session.queen_executor
@@ -868,6 +1169,10 @@ class SessionManager:
event_type = (
EventType.TRIGGER_AVAILABLE if kind == "available" else EventType.TRIGGER_REMOVED
)
# Resolve graph entry node for trigger target
runner = getattr(session, "runner", None)
graph_entry = runner.graph.entry_node if runner else None
for t in triggers.values():
await session.event_bus.publish(
AgentEvent(
@@ -877,6 +1182,8 @@ class SessionManager:
"trigger_id": t.id,
"trigger_type": t.trigger_type,
"trigger_config": t.trigger_config,
"name": t.description or t.id,
**({"entry_node": graph_entry} if graph_entry else {}),
},
)
)
+67
@@ -5,6 +5,7 @@ Uses aiohttp TestClient with mocked sessions to test all endpoints
without requiring actual LLM calls or agent loading.
"""
import asyncio
import json
from dataclasses import dataclass, field
from pathlib import Path
@@ -13,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from aiohttp.test_utils import TestClient, TestServer
from framework.runtime.triggers import TriggerDefinition
from framework.server.app import create_app
from framework.server.session_manager import Session
@@ -172,6 +174,7 @@ def _make_session(
runner.intro_message = "Test intro"
mock_event_bus = MagicMock()
mock_event_bus.publish = AsyncMock()
mock_llm = MagicMock()
queen_executor = _make_queen_executor() if with_queen else None
@@ -484,6 +487,70 @@ class TestSessionCRUD:
data = await resp.json()
assert "primary" in data["graphs"]
@pytest.mark.asyncio
async def test_update_trigger_task(self, tmp_path):
session = _make_session(tmp_dir=tmp_path)
session.available_triggers["daily"] = TriggerDefinition(
id="daily",
trigger_type="timer",
trigger_config={"cron": "0 5 * * *"},
task="Old task",
)
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.patch(
"/api/sessions/test_agent/triggers/daily",
json={"task": "New task"},
)
assert resp.status == 200
data = await resp.json()
assert data["task"] == "New task"
assert data["trigger_config"]["cron"] == "0 5 * * *"
assert session.available_triggers["daily"].task == "New task"
@pytest.mark.asyncio
async def test_update_trigger_cron_restarts_active_timer(self, tmp_path):
session = _make_session(tmp_dir=tmp_path)
session.available_triggers["daily"] = TriggerDefinition(
id="daily",
trigger_type="timer",
trigger_config={"cron": "0 5 * * *"},
task="Run task",
active=True,
)
session.active_trigger_ids.add("daily")
session.active_timer_tasks["daily"] = asyncio.create_task(asyncio.sleep(60))
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.patch(
"/api/sessions/test_agent/triggers/daily",
json={"trigger_config": {"cron": "0 6 * * *"}},
)
assert resp.status == 200
data = await resp.json()
assert data["trigger_config"]["cron"] == "0 6 * * *"
assert "daily" in session.active_timer_tasks
assert session.active_timer_tasks["daily"] is not None
assert session.available_triggers["daily"].trigger_config["cron"] == "0 6 * * *"
session.active_timer_tasks["daily"].cancel()
@pytest.mark.asyncio
async def test_update_trigger_cron_rejects_invalid_expression(self, tmp_path):
session = _make_session(tmp_dir=tmp_path)
session.available_triggers["daily"] = TriggerDefinition(
id="daily",
trigger_type="timer",
trigger_config={"cron": "0 5 * * *"},
task="Run task",
)
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.patch(
"/api/sessions/test_agent/triggers/daily",
json={"trigger_config": {"cron": "not a cron"}},
)
assert resp.status == 400
class TestExecution:
@pytest.mark.asyncio
+11 -2
@@ -1,8 +1,8 @@
"""Hive Agent Skills — discovery, parsing, and injection of SKILL.md packages.
"""Hive Agent Skills — discovery, parsing, trust gating, and injection of SKILL.md packages.
Implements the open Agent Skills standard (agentskills.io) for portable
skill discovery and activation, plus built-in default skills for runtime
operational discipline.
operational discipline, and AS-13 trust gating for project-scope skills.
"""
from framework.skills.catalog import SkillCatalog
@@ -10,7 +10,10 @@ from framework.skills.config import DefaultSkillConfig, SkillsConfig
from framework.skills.defaults import DefaultSkillManager
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
from framework.skills.manager import SkillsManager, SkillsManagerConfig
from framework.skills.models import TrustStatus
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
from framework.skills.trust import TrustedRepoStore, TrustGate
__all__ = [
"DefaultSkillConfig",
@@ -22,5 +25,11 @@ __all__ = [
"SkillsConfig",
"SkillsManager",
"SkillsManagerConfig",
"TrustGate",
"TrustedRepoStore",
"TrustStatus",
"parse_skill_md",
"SkillError",
"SkillErrorCode",
"log_skill_error",
]
+10 -1
@@ -10,6 +10,7 @@ import logging
from xml.sax.saxutils import escape
from framework.skills.parser import ParsedSkill
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@ class SkillCatalog:
lines.append(f" <name>{escape(skill.name)}</name>")
lines.append(f" <description>{escape(skill.description)}</description>")
lines.append(f" <location>{escape(skill.location)}</location>")
lines.append(f" <base_dir>{escape(skill.base_dir)}</base_dir>")
lines.append(" </skill>")
lines.append("</available_skills>")
@@ -96,7 +98,14 @@ class SkillCatalog:
for name in skill_names:
skill = self.get(name)
if skill is None:
logger.warning("Pre-activated skill '%s' not found in catalog", name)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NOT_FOUND,
what=f"Pre-activated skill '{name}' not found in catalog",
why="The skill was listed for pre-activation but was not discovered.",
fix=f"Check that a SKILL.md for '{name}' exists in a scanned directory.",
)
continue
if self.is_activated(name):
continue # Already activated, skip duplicate
+120
@@ -0,0 +1,120 @@
"""CLI commands for the Hive skill system.
Phase 1 commands (AS-13):
hive skill list          list discovered skills across all scopes
hive skill trust <path>  permanently trust a project repo's skills
Full CLI suite (CLI-1 through CLI-13) is Phase 2.
"""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
def register_skill_commands(subparsers) -> None:
"""Register the ``hive skill`` subcommand group."""
skill_parser = subparsers.add_parser("skill", help="Manage skills")
skill_sub = skill_parser.add_subparsers(dest="skill_command", required=True)
# hive skill list
list_parser = skill_sub.add_parser("list", help="List discovered skills across all scopes")
list_parser.add_argument(
"--project-dir",
default=None,
metavar="PATH",
help="Project directory to scan (default: current directory)",
)
list_parser.set_defaults(func=cmd_skill_list)
# hive skill trust
trust_parser = skill_sub.add_parser(
"trust",
help="Permanently trust a project repository so its skills load without prompting",
)
trust_parser.add_argument(
"project_path",
help="Path to the project directory (must contain a .git with a remote origin)",
)
trust_parser.set_defaults(func=cmd_skill_trust)
def cmd_skill_list(args) -> int:
"""List all discovered skills grouped by scope."""
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
project_dir = Path(args.project_dir).resolve() if args.project_dir else Path.cwd()
skills = SkillDiscovery(DiscoveryConfig(project_root=project_dir)).discover()
if not skills:
print("No skills discovered.")
return 0
scope_headers = {
"project": "PROJECT SKILLS",
"user": "USER SKILLS",
"framework": "FRAMEWORK SKILLS",
}
for scope in ("project", "user", "framework"):
scope_skills = [s for s in skills if s.source_scope == scope]
if not scope_skills:
continue
print(f"\n{scope_headers[scope]}")
print("" * 40)
for skill in scope_skills:
print(f"{skill.name}")
print(f" {skill.description}")
print(f" {skill.location}")
return 0
def cmd_skill_trust(args) -> int:
"""Permanently trust a project repository's skills."""
from framework.skills.trust import TrustedRepoStore, _normalize_remote_url
project_path = Path(args.project_path).resolve()
if not project_path.exists():
print(f"Error: path does not exist: {project_path}", file=sys.stderr)
return 1
if not (project_path / ".git").exists():
print(
f"Error: {project_path} is not a git repository (no .git directory).",
file=sys.stderr,
)
return 1
try:
result = subprocess.run(
["git", "-C", str(project_path), "remote", "get-url", "origin"],
capture_output=True,
text=True,
timeout=3,
)
if result.returncode != 0:
print(
"Error: no remote 'origin' configured in this repository.",
file=sys.stderr,
)
return 1
remote_url = result.stdout.strip()
except subprocess.TimeoutExpired:
print("Error: git remote lookup timed out.", file=sys.stderr)
return 1
except (FileNotFoundError, OSError) as e:
print(f"Error reading git remote: {e}", file=sys.stderr)
return 1
repo_key = _normalize_remote_url(remote_url)
store = TrustedRepoStore()
store.trust(repo_key, project_path=str(project_path))
print(f"✓ Trusted: {repo_key}")
print(" Stored in ~/.hive/trusted_repos.json")
print(" Skills from this repository will load without prompting in future runs.")
return 0
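A hedged sketch of how these subcommands could be wired into a top-level parser: the framework.skills.cli module path and the hive entry point are assumptions, but the set_defaults(func=...) dispatch matches the registration code above.
import argparse
import sys

from framework.skills.cli import register_skill_commands  # assumed module path

def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(prog="hive")
    subparsers = parser.add_subparsers(dest="command", required=True)
    register_skill_commands(subparsers)
    args = parser.parse_args(argv)
    return args.func(args)  # handlers installed via set_defaults(func=...)

if __name__ == "__main__":
    sys.exit(main())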
+53 -4
@@ -11,6 +11,7 @@ from pathlib import Path
from framework.skills.config import SkillsConfig
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -60,12 +61,14 @@ class DefaultSkillManager:
self._config = config or SkillsConfig()
self._skills: dict[str, ParsedSkill] = {}
self._loaded = False
self._error_count = 0
def load(self) -> None:
"""Load all enabled default skill SKILL.md files."""
if self._loaded:
return
error_count = 0
for skill_name, dir_name in SKILL_REGISTRY.items():
if not self._config.is_default_enabled(skill_name):
logger.info("Default skill '%s' disabled by config", skill_name)
@@ -73,17 +76,34 @@ class DefaultSkillManager:
skill_path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
if not skill_path.is_file():
logger.error("Default skill SKILL.md not found: %s", skill_path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_NOT_FOUND,
what=f"Default skill SKILL.md not found: '{skill_path}'",
why=f"The framework skill '{skill_name}' is missing its SKILL.md file.",
fix="Reinstall the hive framework — this file is part of the package.",
)
error_count += 1
continue
parsed = parse_skill_md(skill_path, source_scope="framework")
if parsed is None:
logger.error("Failed to parse default skill: %s", skill_path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Failed to parse default skill '{skill_name}'",
why=f"parse_skill_md returned None for '{skill_path}'.",
fix="Reinstall the hive framework — this file may be corrupted.",
)
error_count += 1
continue
self._skills[skill_name] = parsed
self._loaded = True
self._error_count = error_count
def build_protocols_prompt(self) -> str:
"""Build the combined operational protocols section.
@@ -127,8 +147,23 @@ class DefaultSkillManager:
"""Log which default skills are active and their configuration."""
if not self._skills:
logger.info("Default skills: all disabled")
return
# DX-3: Per-skill structured startup log
for skill_name in SKILL_REGISTRY:
if skill_name in self._skills:
overrides = self._config.get_default_overrides(skill_name)
status = f"loaded overrides={overrides}" if overrides else "loaded"
elif not self._config.is_default_enabled(skill_name):
status = "disabled"
else:
status = "error"
logger.info(
"skill_startup name=%s scope=framework status=%s",
skill_name,
status,
)
# Original active skills log line (preserved for backward compatibility)
active = []
for skill_name in SKILL_REGISTRY:
if skill_name in self._skills:
@@ -138,7 +173,21 @@ class DefaultSkillManager:
else:
active.append(skill_name)
logger.info("Default skills active: %s", ", ".join(active))
if active:
logger.info("Default skills active: %s", ", ".join(active))
# DX-3: Summary line with error count
total = len(SKILL_REGISTRY)
active_count = len(self._skills)
error_count = getattr(self, "_error_count", 0)
disabled_count = total - active_count - error_count
logger.info(
"Skills: %d default (%d active, %d disabled, %d error)",
total,
active_count,
disabled_count,
error_count,
)
@property
def active_skill_names(self) -> list[str]:
+8 -5
@@ -11,6 +11,7 @@ from dataclasses import dataclass
from pathlib import Path
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
@@ -172,11 +173,13 @@ class SkillDiscovery:
for skill in skills:
if skill.name in seen:
existing = seen[skill.name]
logger.warning(
"Skill name collision: '%s' from %s overrides %s",
skill.name,
skill.location,
existing.location,
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_COLLISION,
what=f"Skill name collision: '{skill.name}'",
why=f"'{skill.location}' overrides '{existing.location}'.",
fix="Rename one of the conflicting skill directories to use a unique name.",
)
seen[skill.name] = skill
+29
@@ -42,11 +42,14 @@ class SkillsManagerConfig:
When ``None``, community discovery is skipped.
skip_community_discovery: Explicitly skip community scanning
even when ``project_root`` is set.
interactive: Whether trust gating can prompt the user interactively.
When ``False``, untrusted project skills are silently skipped.
"""
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
project_root: Path | None = None
skip_community_discovery: bool = False
interactive: bool = True
class SkillsManager:
@@ -63,6 +66,7 @@ class SkillsManager:
self._loaded = False
self._catalog_prompt: str = ""
self._protocols_prompt: str = ""
self._allowlisted_dirs: list[str] = []
# ------------------------------------------------------------------
# Factory for backwards-compat bridge
@@ -85,6 +89,7 @@ class SkillsManager:
mgr._loaded = True # skip load()
mgr._catalog_prompt = skills_catalog_prompt
mgr._protocols_prompt = protocols_prompt
mgr._allowlisted_dirs = []
return mgr
# ------------------------------------------------------------------
@@ -113,9 +118,18 @@ class SkillsManager:
# 1. Community skill discovery (when project_root is available)
catalog_prompt = ""
if self._config.project_root is not None and not self._config.skip_community_discovery:
from framework.skills.trust import TrustGate
discovery = SkillDiscovery(DiscoveryConfig(project_root=self._config.project_root))
discovered = discovery.discover()
# Trust-gate project-scope skills (AS-13)
discovered = TrustGate(interactive=self._config.interactive).filter_and_gate(
discovered, project_dir=self._config.project_root
)
catalog = SkillCatalog(discovered)
self._allowlisted_dirs = catalog.allowlisted_dirs
catalog_prompt = catalog.to_prompt()
# Pre-activated community skills
@@ -132,6 +146,16 @@ class SkillsManager:
default_mgr.load()
default_mgr.log_active_skills()
protocols_prompt = default_mgr.build_protocols_prompt()
# DX-3: Community skill startup summary
if self._config.project_root is not None and not self._config.skip_community_discovery:
community_count = len(catalog._skills) if catalog_prompt else 0
pre_activated_count = len(skills_config.skills) if skills_config.skills else 0
logger.info(
"Skills: %d community (%d catalog, %d pre-activated)",
community_count,
community_count,
pre_activated_count,
)
# 3. Cache
self._catalog_prompt = catalog_prompt
@@ -160,6 +184,11 @@ class SkillsManager:
"""Default skill operational protocols for system prompt injection."""
return self._protocols_prompt
@property
def allowlisted_dirs(self) -> list[str]:
"""Skill base directories for Tier 3 resource access (AS-6)."""
return self._allowlisted_dirs
@property
def is_loaded(self) -> bool:
return self._loaded
+52
@@ -0,0 +1,52 @@
"""Data models for the Hive skill system (Agent Skills standard)."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from pathlib import Path
class SkillScope(StrEnum):
"""Where a skill was discovered."""
PROJECT = "project"
USER = "user"
FRAMEWORK = "framework"
class TrustStatus(StrEnum):
"""Trust state of a skill entry."""
TRUSTED = "trusted"
PENDING_CONSENT = "pending_consent"
DENIED = "denied"
@dataclass
class SkillEntry:
"""In-memory record for a discovered skill (PRD §4.2)."""
name: str
"""Skill name from SKILL.md frontmatter."""
description: str
"""Skill description from SKILL.md frontmatter."""
location: Path
"""Absolute path to SKILL.md."""
base_dir: Path
"""Parent directory of SKILL.md (skill root)."""
source_scope: SkillScope
"""Which scope this skill was found in."""
trust_status: TrustStatus = TrustStatus.TRUSTED
"""Trust state; project-scope skills start as PENDING_CONSENT before gating."""
# Optional frontmatter fields
license: str | None = None
compatibility: list[str] = field(default_factory=list)
allowed_tools: list[str] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
+81 -14
@@ -13,6 +13,8 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any
from framework.skills.skill_errors import SkillErrorCode, log_skill_error
logger = logging.getLogger(__name__)
# Maximum name length before a warning is logged
@@ -74,17 +76,38 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
try:
content = path.read_text(encoding="utf-8")
except OSError as exc:
logger.error("Failed to read %s: %s", path, exc)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Failed to read '{path}'",
why=str(exc),
fix="Check the file exists and has read permissions.",
)
return None
if not content.strip():
logger.error("Empty SKILL.md: %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="The file exists but contains no content.",
fix="Add valid YAML frontmatter and a markdown body to the SKILL.md.",
)
return None
# Split on --- delimiters (first two occurrences)
parts = content.split("---", 2)
if len(parts) < 3:
logger.error("SKILL.md missing YAML frontmatter delimiters (---): %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="Missing YAML frontmatter (---).",
fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
)
return None
# parts[0] is content before first --- (should be empty or whitespace)
@@ -94,7 +117,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
body = parts[2].strip()
if not raw_yaml:
logger.error("Empty YAML frontmatter in %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="The --- delimiters are present but the YAML block is empty.",
fix="Add at least 'name' and 'description' fields to the frontmatter.",
)
return None
# Parse YAML
@@ -108,19 +138,47 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
try:
fixed = _try_fix_yaml(raw_yaml)
frontmatter = yaml.safe_load(fixed)
logger.warning("Fixed YAML parse issues in %s (unquoted colons)", path)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_YAML_FIXUP,
what=f"Auto-fixed YAML in '{path}'",
why="Unquoted colon values detected in frontmatter.",
fix='Wrap values containing colons in quotes, e.g. description: "Use for: research"',
)
except yaml.YAMLError as exc:
logger.error("Unparseable YAML in %s: %s", path, exc)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why=str(exc),
fix="Validate the YAML frontmatter at https://yaml-online-parser.appspot.com/",
)
return None
if not isinstance(frontmatter, dict):
logger.error("YAML frontmatter is not a mapping in %s", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what=f"Invalid SKILL.md at '{path}'",
why="YAML frontmatter is not a key-value mapping.",
fix="Ensure the frontmatter is valid YAML with key: value pairs.",
)
return None
# Required: description
description = frontmatter.get("description")
if not description or not str(description).strip():
logger.error("Missing or empty 'description' in %s — skipping skill", path)
log_skill_error(
logger,
"error",
SkillErrorCode.SKILL_MISSING_DESCRIPTION,
what=f"Missing 'description' in '{path}'",
why="The 'description' field is required but is absent or empty.",
fix="Add a non-empty 'description' field to the YAML frontmatter.",
)
return None
# Required: name (fallback to parent directory name)
@@ -128,7 +186,14 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
parent_dir_name = path.parent.name
if not name or not str(name).strip():
name = parent_dir_name
logger.warning("Missing 'name' in %s — using directory name '%s'", path, name)
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NAME_MISMATCH,
what=f"Missing 'name' in '{path}' — using directory name '{name}'",
why="The 'name' field is absent from the YAML frontmatter.",
fix=f"Add 'name: {name}' to the frontmatter to make this explicit.",
)
else:
name = str(name).strip()
@@ -137,11 +202,13 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
logger.warning("Skill name exceeds %d chars in %s: '%s'", _MAX_NAME_LENGTH, path, name)
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
logger.warning(
"Skill name '%s' doesn't match parent directory '%s' in %s",
name,
parent_dir_name,
path,
log_skill_error(
logger,
"warning",
SkillErrorCode.SKILL_NAME_MISMATCH,
what=f"Name mismatch in '{path}'",
why=f"Skill name '{name}' doesn't match directory '{parent_dir_name}'.",
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
)
return ParsedSkill(
+70
@@ -0,0 +1,70 @@
"""Structured error codes and diagnostics for the Hive skill system.
Implements DX-1 (structured error codes) and DX-2 (what/why/fix format)
from the skill system PRD §7.5.
"""
from __future__ import annotations
import logging
from enum import Enum
class SkillErrorCode(Enum):
"""Standardized error codes for skill system operations (DX-1)."""
SKILL_NOT_FOUND = "SKILL_NOT_FOUND"
SKILL_PARSE_ERROR = "SKILL_PARSE_ERROR"
SKILL_ACTIVATION_FAILED = "SKILL_ACTIVATION_FAILED"
SKILL_MISSING_DESCRIPTION = "SKILL_MISSING_DESCRIPTION"
SKILL_YAML_FIXUP = "SKILL_YAML_FIXUP"
SKILL_NAME_MISMATCH = "SKILL_NAME_MISMATCH"
SKILL_COLLISION = "SKILL_COLLISION"
class SkillError(Exception):
"""Structured exception for skill system errors (DX-2).
Raised in strict validation paths. Also used as the base
format contract for log_skill_error() log messages.
"""
def __init__(self, code: SkillErrorCode, what: str, why: str, fix: str):
self.code = code
self.what = what
self.why = why
self.fix = fix
self.message = (
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
)
super().__init__(self.message)
def log_skill_error(
logger: logging.Logger,
level: str,
code: SkillErrorCode,
what: str,
why: str,
fix: str,
) -> None:
"""Emit a structured skill diagnostic log with consistent format (DX-2).
Args:
logger: The module logger to emit to.
level: Log level string 'error', 'warning', or 'info'.
code: Structured error code.
what: What failed (specific skill name and path).
why: Root cause.
fix: Concrete next step for the developer.
"""
msg = f"[{code.value}] What failed: {what} | Why: {why} | Fix: {fix}"
getattr(logger, level)(
msg,
extra={
"skill_error_code": code.value,
"what": what,
"why": why,
"fix": fix,
},
)
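Usage mirrors the call sites elsewhere in this compare; the SKILL.md path below is illustrative:
import logging

logger = logging.getLogger(__name__)

# Structured diagnostic log in the DX-2 what/why/fix format:
log_skill_error(
    logger,
    "error",
    SkillErrorCode.SKILL_PARSE_ERROR,
    what="Invalid SKILL.md at 'skills/research/SKILL.md'",
    why="Missing YAML frontmatter (---).",
    fix="Wrap the frontmatter with --- on its own line at the top and bottom.",
)

def validate_description(frontmatter: dict, path: str) -> None:
    """Strict-path variant (assumed usage pattern): raise instead of logging."""
    if not frontmatter.get("description"):
        raise SkillError(
            SkillErrorCode.SKILL_MISSING_DESCRIPTION,
            what=f"Missing 'description' in '{path}'",
            why="The 'description' field is required but is absent or empty.",
            fix="Add a non-empty 'description' field to the YAML frontmatter.",
        )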
+477
@@ -0,0 +1,477 @@
"""Trust gating for project-level skills (PRD AS-13).
Project-level skills from untrusted repositories require explicit user consent
before their instructions are loaded into the agent's system prompt.
Framework and user-scope skills are always trusted.
Trusted repos are persisted at ~/.hive/trusted_repos.json.
"""
from __future__ import annotations
import json
import logging
import subprocess
import sys
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
from pathlib import Path
from urllib.parse import urlparse
from framework.skills.parser import ParsedSkill
logger = logging.getLogger(__name__)
# Env var to bypass trust gating in CI/headless pipelines (opt-in).
_ENV_TRUST_ALL = "HIVE_TRUST_PROJECT_SKILLS"
# Env var for comma-separated own-remote glob patterns (e.g. "github.com/myorg/*").
_ENV_OWN_REMOTES = "HIVE_OWN_REMOTES"
_TRUSTED_REPOS_PATH = Path.home() / ".hive" / "trusted_repos.json"
_NOTICE_SENTINEL_PATH = Path.home() / ".hive" / ".skill_trust_notice_shown"
# ---------------------------------------------------------------------------
# Trusted repo store
# ---------------------------------------------------------------------------
@dataclass
class TrustedRepoEntry:
repo_key: str
added_at: datetime
project_path: str = ""
class TrustedRepoStore:
"""Persists permanently-trusted repo keys to ~/.hive/trusted_repos.json."""
def __init__(self, path: Path | None = None) -> None:
self._path = path or _TRUSTED_REPOS_PATH
self._entries: dict[str, TrustedRepoEntry] = {}
self._loaded = False
def is_trusted(self, repo_key: str) -> bool:
self._ensure_loaded()
return repo_key in self._entries
def trust(self, repo_key: str, project_path: str = "") -> None:
self._ensure_loaded()
self._entries[repo_key] = TrustedRepoEntry(
repo_key=repo_key,
added_at=datetime.now(tz=UTC),
project_path=project_path,
)
self._save()
logger.info("skill_trust_store: trusted repo_key=%s", repo_key)
def revoke(self, repo_key: str) -> bool:
self._ensure_loaded()
if repo_key in self._entries:
del self._entries[repo_key]
self._save()
logger.info("skill_trust_store: revoked repo_key=%s", repo_key)
return True
return False
def list_entries(self) -> list[TrustedRepoEntry]:
self._ensure_loaded()
return list(self._entries.values())
def _ensure_loaded(self) -> None:
if not self._loaded:
self._load()
self._loaded = True
def _load(self) -> None:
try:
data = json.loads(self._path.read_text(encoding="utf-8"))
for raw in data.get("entries", []):
repo_key = raw.get("repo_key", "")
if not repo_key:
continue
try:
added_at = datetime.fromisoformat(raw["added_at"])
except (KeyError, ValueError):
added_at = datetime.now(tz=UTC)
self._entries[repo_key] = TrustedRepoEntry(
repo_key=repo_key,
added_at=added_at,
project_path=raw.get("project_path", ""),
)
except FileNotFoundError:
pass
except Exception as e:
logger.warning(
"skill_trust_store: could not read %s (%s); treating as empty",
self._path,
e,
)
def _save(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
data = {
"version": 1,
"entries": [
{
"repo_key": e.repo_key,
"added_at": e.added_at.isoformat(),
"project_path": e.project_path,
}
for e in self._entries.values()
],
}
# Atomic write: write to .tmp then rename
tmp = self._path.with_suffix(".tmp")
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
tmp.replace(self._path)
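Putting the store together: the on-disk shape below follows directly from _save(); the repo key and paths are illustrative, and note the default store writes to the real ~/.hive.
store = TrustedRepoStore()
store.trust("github.com/org/repo", project_path="/home/user/repo")
assert store.is_trusted("github.com/org/repo")

# Resulting ~/.hive/trusted_repos.json:
# {
#   "version": 1,
#   "entries": [
#     {
#       "repo_key": "github.com/org/repo",
#       "added_at": "2026-03-20T12:00:00+00:00",
#       "project_path": "/home/user/repo"
#     }
#   ]
# }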
# ---------------------------------------------------------------------------
# Trust classification
# ---------------------------------------------------------------------------
class ProjectTrustClassification(StrEnum):
ALWAYS_TRUSTED = "always_trusted"
TRUSTED_BY_USER = "trusted_by_user"
UNTRUSTED = "untrusted"
class ProjectTrustDetector:
"""Classifies a project directory as trusted or untrusted.
Algorithm (PRD §4.1 trust note):
1. No project_dir → ALWAYS_TRUSTED
2. No .git directory → ALWAYS_TRUSTED (not a git repo)
3. No remote 'origin' → ALWAYS_TRUSTED (local-only repo)
4. Remote URL → repo_key; in TrustedRepoStore → TRUSTED_BY_USER
5. Localhost remote → ALWAYS_TRUSTED
6. ~/.hive/own_remotes match → ALWAYS_TRUSTED
7. HIVE_OWN_REMOTES env match → ALWAYS_TRUSTED
8. None of the above → UNTRUSTED
"""
def __init__(self, store: TrustedRepoStore | None = None) -> None:
self._store = store or TrustedRepoStore()
def classify(self, project_dir: Path | None) -> tuple[ProjectTrustClassification, str]:
"""Return (classification, repo_key).
repo_key is empty string for ALWAYS_TRUSTED cases without a remote.
"""
if project_dir is None or not project_dir.exists():
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
if not (project_dir / ".git").exists():
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
remote_url = self._get_remote_origin(project_dir)
if not remote_url:
return ProjectTrustClassification.ALWAYS_TRUSTED, ""
repo_key = _normalize_remote_url(remote_url)
# Explicitly trusted by user
if self._store.is_trusted(repo_key):
return ProjectTrustClassification.TRUSTED_BY_USER, repo_key
# Localhost remotes are always trusted
if _is_localhost_remote(remote_url):
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
# User-configured own-remote patterns
if self._matches_own_remotes(repo_key):
return ProjectTrustClassification.ALWAYS_TRUSTED, repo_key
return ProjectTrustClassification.UNTRUSTED, repo_key
def _get_remote_origin(self, project_dir: Path) -> str:
"""Run git remote get-url origin. Returns empty string on any failure."""
try:
result = subprocess.run(
["git", "-C", str(project_dir), "remote", "get-url", "origin"],
capture_output=True,
text=True,
timeout=3,
)
if result.returncode == 0:
return result.stdout.strip()
except subprocess.TimeoutExpired:
logger.warning(
"skill_trust: git remote lookup timed out for %s; treating as trusted",
project_dir,
)
except (FileNotFoundError, OSError):
pass # git not found or other OS error
return ""
def _matches_own_remotes(self, repo_key: str) -> bool:
"""Check repo_key against user-configured own-remote glob patterns."""
import fnmatch
patterns: list[str] = []
# From env var
env_patterns = _ENV_OWN_REMOTES
import os
raw = os.environ.get(env_patterns, "")
if raw:
patterns.extend(p.strip() for p in raw.split(",") if p.strip())
# From ~/.hive/own_remotes file
own_remotes_file = Path.home() / ".hive" / "own_remotes"
if own_remotes_file.is_file():
try:
for line in own_remotes_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if line and not line.startswith("#"):
patterns.append(line)
except OSError:
pass
return any(fnmatch.fnmatch(repo_key, p) for p in patterns)
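The pattern sources above accept fnmatch-style globs; an illustrative configuration (all values hypothetical):
# Via environment variable (comma-separated):
#   HIVE_OWN_REMOTES="github.com/myorg/*,gitlab.com/me/*"
#
# Or via ~/.hive/own_remotes, one pattern per line, '#' comments allowed:
#   # internal org repos
#   github.com/myorg/*
#   gitlab.com/me/*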
# ---------------------------------------------------------------------------
# URL helpers (public so CLI can reuse)
# ---------------------------------------------------------------------------
def _normalize_remote_url(url: str) -> str:
"""Normalize a git remote URL to a canonical ``host/org/repo`` key.
Examples:
git@github.com:org/repo.git → github.com/org/repo
https://github.com/org/repo → github.com/org/repo
ssh://git@github.com/org/repo.git → github.com/org/repo
"""
url = url.strip()
# SCP-style SSH: git@github.com:org/repo.git
if url.startswith("git@") and ":" in url and "://" not in url:
url = url[4:] # strip git@
url = url.replace(":", "/", 1)
elif "://" in url:
parsed = urlparse(url)
host = parsed.hostname or ""
path = parsed.path.lstrip("/")
url = f"{host}/{path}"
# Strip .git suffix
if url.endswith(".git"):
url = url[:-4]
return url.lower().strip("/")
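The docstring examples above double as checks, since this normalization is pure string manipulation:
assert _normalize_remote_url("git@github.com:org/repo.git") == "github.com/org/repo"
assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"
assert _normalize_remote_url("ssh://git@github.com/org/repo.git") == "github.com/org/repo"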
def _is_localhost_remote(remote_url: str) -> bool:
"""Return True if the remote points to a local host."""
local_hosts = {"localhost", "127.0.0.1", "::1"}
try:
if "://" in remote_url:
parsed = urlparse(remote_url)
return (parsed.hostname or "").lower() in local_hosts
# SCP-style: git@localhost:org/repo
if "@" in remote_url:
host_part = remote_url.split("@", 1)[1].split(":")[0]
return host_part.lower() in local_hosts
except Exception:
pass
return False
# ---------------------------------------------------------------------------
# Trust gate
# ---------------------------------------------------------------------------
class TrustGate:
"""Filters skill list, running consent flow for untrusted project-scope skills.
Framework and user-scope skills are always allowed through.
Project-scope skills from untrusted repos require consent.
"""
def __init__(
self,
store: TrustedRepoStore | None = None,
detector: ProjectTrustDetector | None = None,
interactive: bool = True,
print_fn: Callable[[str], None] | None = None,
input_fn: Callable[[str], str] | None = None,
) -> None:
self._store = store or TrustedRepoStore()
self._detector = detector or ProjectTrustDetector(self._store)
self._interactive = interactive
self._print = print_fn or print
self._input = input_fn or input
def filter_and_gate(
self,
skills: list[ParsedSkill],
project_dir: Path | None,
) -> list[ParsedSkill]:
"""Return the subset of skills that are trusted for loading.
- Framework and user-scope skills: always included.
- Project-scope skills: classified; consent prompt shown if untrusted.
"""
import os
# Separate project skills from always-trusted scopes
always_trusted = [s for s in skills if s.source_scope != "project"]
project_skills = [s for s in skills if s.source_scope == "project"]
if not project_skills:
return always_trusted
# Env-var CI override: trust all project skills for this invocation
if os.environ.get(_ENV_TRUST_ALL, "").strip() == "1":
logger.info(
"skill_trust: %s=1 set; trusting %d project skill(s) without consent",
_ENV_TRUST_ALL,
len(project_skills),
)
return always_trusted + project_skills
classification, repo_key = self._detector.classify(project_dir)
if classification in (
ProjectTrustClassification.ALWAYS_TRUSTED,
ProjectTrustClassification.TRUSTED_BY_USER,
):
logger.info(
"skill_trust: project skills trusted classification=%s repo=%s count=%d",
classification,
repo_key or "(no remote)",
len(project_skills),
)
return always_trusted + project_skills
# UNTRUSTED — need consent
if not self._interactive or not sys.stdin.isatty():
logger.warning(
"skill_trust: skipping %d project-scope skill(s) from untrusted repo "
"'%s' (non-interactive mode). "
"To trust permanently run: hive skill trust %s",
len(project_skills),
repo_key,
project_dir or ".",
)
logger.info(
"skill_trust_decision repo=%s skills=%d decision=denied mode=headless",
repo_key,
len(project_skills),
)
return always_trusted
# Interactive consent flow
decision = self._run_consent_flow(project_skills, project_dir, repo_key)
logger.info(
"skill_trust_decision repo=%s skills=%d decision=%s mode=interactive",
repo_key,
len(project_skills),
decision,
)
if decision == "session":
return always_trusted + project_skills
if decision == "permanent":
self._store.trust(repo_key, project_path=str(project_dir or ""))
return always_trusted + project_skills
# denied
return always_trusted
def _run_consent_flow(
self,
project_skills: list[ParsedSkill],
project_dir: Path | None,
repo_key: str,
) -> str:
"""Show the security notice (once) and consent prompt.
Return 'session' | 'permanent' | 'denied'."""
from framework.credentials.setup import Colors
if not sys.stdout.isatty():
Colors.disable()
self._maybe_show_security_notice(Colors)
self._print_consent_prompt(project_skills, project_dir, repo_key, Colors)
return self._prompt_consent(Colors)
def _maybe_show_security_notice(self, Colors) -> None: # noqa: N803
"""Show the one-time security notice if not already shown (NFR-5)."""
if _NOTICE_SENTINEL_PATH.exists():
return
self._print("")
self._print(
f"{Colors.YELLOW}Security notice:{Colors.NC} Skills inject instructions "
"into the agent's system prompt."
)
self._print(
" Only load skills from sources you trust. "
"Registry skills at tier 'verified' or 'official' have been audited."
)
self._print("")
try:
_NOTICE_SENTINEL_PATH.parent.mkdir(parents=True, exist_ok=True)
_NOTICE_SENTINEL_PATH.touch()
except OSError:
pass
def _print_consent_prompt(
self,
project_skills: list[ParsedSkill],
project_dir: Path | None,
repo_key: str,
Colors, # noqa: N803
) -> None:
p = self._print
p("")
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
p(f"{Colors.BOLD} SKILL TRUST REQUIRED{Colors.NC}")
p(f"{Colors.YELLOW}{'=' * 60}{Colors.NC}")
p("")
proj_label = str(project_dir) if project_dir else "this project"
p(
f" The project at {Colors.CYAN}{proj_label}{Colors.NC} wants to load "
f"{len(project_skills)} skill(s)"
)
p(" that will inject instructions into the agent's system prompt.")
if repo_key:
p(f" Source: {Colors.BOLD}{repo_key}{Colors.NC}")
p("")
p(" Skills requesting access:")
for skill in project_skills:
p(f" {Colors.CYAN}{Colors.NC} {Colors.BOLD}{skill.name}{Colors.NC}")
p(f' "{skill.description}"')
p(f" {Colors.DIM}{skill.location}{Colors.NC}")
p("")
p(" Options:")
p(f" {Colors.CYAN}1){Colors.NC} Trust this session only")
p(f" {Colors.CYAN}2){Colors.NC} Trust permanently — remember for future runs")
p(
f" {Colors.DIM}3) Deny"
f" — skip all project-scope skills from this repo{Colors.NC}"
)
p(f"{Colors.YELLOW}{'' * 60}{Colors.NC}")
def _prompt_consent(self, Colors) -> str: # noqa: N803
"""Prompt until a valid choice is entered. Returns 'session'|'permanent'|'denied'."""
mapping = {"1": "session", "2": "permanent", "3": "denied"}
while True:
try:
choice = self._input("Select option (1-3): ").strip()
if choice in mapping:
return mapping[choice]
except (KeyboardInterrupt, EOFError):
return "denied"
self._print(f"{Colors.RED}Invalid choice. Enter 1, 2, or 3.{Colors.NC}")
+122 -13
@@ -727,6 +727,25 @@ def _dissolve_planning_nodes(
return converted, flowchart_map
def _update_meta_json(session_manager, manager_session_id, updates: dict) -> None:
"""Merge updates into the queen session's meta.json."""
if session_manager is None or not manager_session_id:
return
srv_session = session_manager.get_session(manager_session_id)
if not srv_session:
return
storage_sid = getattr(srv_session, "queen_resume_from", None) or srv_session.id
meta_path = Path.home() / ".hive" / "queen" / "session" / storage_sid / "meta.json"
try:
existing = {}
if meta_path.exists():
existing = json.loads(meta_path.read_text(encoding="utf-8"))
existing.update(updates)
meta_path.write_text(json.dumps(existing), encoding="utf-8")
except OSError:
pass
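An illustrative sequence (session ID and paths hypothetical) showing how the merge accumulates lifecycle state for cold restore:
_update_meta_json(session_manager, "sess-1", {"phase": "building"})
_update_meta_json(session_manager, "sess-1", {"agent_path": "exports/my_agent"})
_update_meta_json(session_manager, "sess-1", {"phase": "staging"})
# meta.json now holds the merged view:
# {"phase": "staging", "agent_path": "exports/my_agent"}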
def register_queen_lifecycle_tools(
registry: ToolRegistry,
session: Any = None,
@@ -975,6 +994,7 @@ def register_queen_lifecycle_tools(
# Switch to building phase
if phase_state is not None:
await phase_state.switch_to_building()
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
result = json.loads(stop_result)
result["phase"] = "building"
@@ -1559,12 +1579,22 @@ def register_queen_lifecycle_tools(
# Find edges where this leaf node is the source
out_edges = [e for e in validated_edges if e["source"] == leaf_id]
in_edges = [e for e in validated_edges if e["target"] == leaf_id]
if not out_edges:
continue # already a proper leaf
# Identify the parent (predecessor that connects IN)
parent_ids = [e["source"] for e in in_edges]
if not out_edges:
# Already a proper leaf — still ensure sub_agents is set
for pid in parent_ids:
parent = node_by_id_v.get(pid)
if parent is None:
continue
existing = parent.get("sub_agents") or []
if leaf_id not in existing:
existing.append(leaf_id)
parent["sub_agents"] = existing
continue
# Strip all outgoing edges from the leaf node that
# don't go back to a parent (report edges are OK)
illegal_targets: list[str] = []
@@ -1978,6 +2008,17 @@ def register_queen_lifecycle_tools(
"type": "string",
"description": "What success looks like for this node",
},
"sub_agents": {
"type": "array",
"items": {"type": "string"},
"description": (
"IDs of GCU/browser sub-agent nodes managed by this node. "
"At build time, sub-agent nodes are dissolved into this list. "
"Set this on the PARENT node — e.g. the orchestrator that "
"delegates to GCU leaves. Visual delegation edges are "
"synthesized automatically."
),
},
"decision_clause": {
"type": "string",
"description": (
@@ -2095,8 +2136,22 @@ def register_queen_lifecycle_tools(
phase_state.draft_graph = converted
phase_state.flowchart_map = fmap
# Note: flowchart file is persisted later, in initialize_and_build_agent
# (after the agent folder is scaffolded) or in load_built_agent.
# Create agent folder early so flowchart and agent_path are available
# throughout the entire BUILDING phase.
_agent_name = phase_state.draft_graph.get("agent_name", "").strip()
if _agent_name:
_agent_folder = Path("exports") / _agent_name
_agent_folder.mkdir(parents=True, exist_ok=True)
_save_flowchart_file(_agent_folder, original_copy, fmap)
phase_state.agent_path = str(_agent_folder)
_update_meta_json(
session_manager,
manager_session_id,
{
"agent_path": str(_agent_folder),
"agent_name": _agent_name.replace("_", " ").title(),
},
)
dissolved_count = len(original_nodes) - len(converted.get("nodes", []))
decision_count = sum(1 for n in original_nodes if n.get("flowchart_type") == "decision")
@@ -2228,6 +2283,7 @@ def register_queen_lifecycle_tools(
if fallback_path:
phase_state.agent_path = str(fallback_path)
await phase_state.switch_to_building(source="tool")
_update_meta_json(session_manager, manager_session_id, {"phase": "building"})
if phase_state.inject_notification:
await phase_state.inject_notification(
"[PHASE CHANGE] Switched to BUILDING phase. "
@@ -2270,8 +2326,13 @@ def register_queen_lifecycle_tools(
if parsed.get("success", True):
if phase_state is not None:
# Set agent_path so the frontend can query credentials
phase_state.agent_path = str(Path("exports") / agent_name)
phase_state.agent_path = phase_state.agent_path or str(
Path("exports") / agent_name
)
await phase_state.switch_to_building(source="tool")
_update_meta_json(
session_manager, manager_session_id, {"phase": "building"}
)
# Reset draft state after successful scaffolding
phase_state.build_confirmed = False
# Persist flowchart now that the agent folder exists
@@ -2319,6 +2380,7 @@ def register_queen_lifecycle_tools(
# Switch to staging phase
if phase_state is not None:
await phase_state.switch_to_staging()
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
result = json.loads(stop_result)
result["phase"] = "staging"
@@ -2347,6 +2409,30 @@ def register_queen_lifecycle_tools(
"""Get the session's event bus for querying history."""
return getattr(session, "event_bus", None)
def _get_worker_name() -> str | None:
"""Return the worker agent directory name, used for diary lookups."""
p = getattr(session, "worker_path", None)
return p.name if p else None
def _format_diary(max_runs: int) -> str:
"""Read recent run digests from disk — no EventBus required."""
agent_name = _get_worker_name()
if not agent_name:
return "No worker loaded — diary unavailable."
from framework.agents.worker_memory import read_recent_digests
entries = read_recent_digests(agent_name, max_runs)
if not entries:
return (
f"No run digests for '{agent_name}' yet. "
"Digests are written at the end of each completed run."
)
lines = [f"Worker '{agent_name}'{len(entries)} recent run digest(s):", ""]
for _run_id, content in entries:
lines.append(content)
lines.append("")
return "\n".join(lines).rstrip()
# Tiered cooldowns: summary is free, detail has short cooldown, full keeps 30s
_COOLDOWN_FULL = 30.0
_COOLDOWN_DETAIL = 10.0
@@ -2949,16 +3035,17 @@ def register_queen_lifecycle_tools(
import time as _time
# --- Tiered cooldown ---
# diary is free (file reads only), summary is free, detail has 10s, full has 30s
now = _time.monotonic()
if focus == "full":
cooldown = _COOLDOWN_FULL
tier = "full"
elif focus is not None:
elif focus == "diary" or focus is None:
cooldown = 0.0
tier = focus or "summary"
else:
cooldown = _COOLDOWN_DETAIL
tier = "detail"
else:
cooldown = 0.0
tier = "summary"
elapsed_since = now - _status_last_called.get(tier, 0.0)
if elapsed_since < cooldown:
@@ -2974,6 +3061,10 @@ def register_queen_lifecycle_tools(
)
_status_last_called[tier] = now
# --- Diary: pure file reads, no runtime required ---
if focus == "diary":
return _format_diary(last_n)
# --- Runtime check ---
runtime = _get_runtime()
if runtime is None:
@@ -3023,7 +3114,7 @@ def register_queen_lifecycle_tools(
else:
return (
f"Unknown focus '{focus}'. "
"Valid options: activity, memory, tools, issues, progress, full."
"Valid options: diary, activity, memory, tools, issues, progress, full."
)
except Exception as exc:
logger.exception("get_worker_status error")
@@ -3034,6 +3125,8 @@ def register_queen_lifecycle_tools(
description=(
"Check on the worker. Returns a brief prose summary by default. "
"Use 'focus' to drill into specifics:\n"
"- diary: persistent run digests from past executions — read this first "
"before digging into live runtime logs\n"
"- activity: current node, transitions, latest LLM output\n"
"- memory: worker's accumulated knowledge and state\n"
"- tools: running and recent tool calls\n"
@@ -3046,8 +3139,11 @@ def register_queen_lifecycle_tools(
"properties": {
"focus": {
"type": "string",
"enum": ["activity", "memory", "tools", "issues", "progress", "full"],
"description": ("Aspect to inspect. Omit for a brief summary."),
"enum": ["diary", "activity", "memory", "tools", "issues", "progress", "full"],
"description": (
"Aspect to inspect. Omit for a brief summary. "
"Use 'diary' to read persistent run history before checking live logs."
),
},
"last_n": {
"type": "integer",
@@ -3446,6 +3542,7 @@ def register_queen_lifecycle_tools(
if phase_state is not None:
phase_state.agent_path = str(resolved_path)
await phase_state.switch_to_staging()
_update_meta_json(session_manager, manager_session_id, {"phase": "staging"})
worker_name = info.name if info else updated_session.worker_id
return json.dumps(
@@ -3565,6 +3662,7 @@ def register_queen_lifecycle_tools(
# Switch to running phase
if phase_state is not None:
await phase_state.switch_to_running()
_update_meta_json(session_manager, manager_session_id, {"phase": "running"})
return json.dumps(
{
@@ -3702,6 +3800,8 @@ def register_queen_lifecycle_tools(
_save_trigger_to_agent(session, trigger_id, tdef)
bus = getattr(session, "event_bus", None)
if bus:
_runner = getattr(session, "runner", None)
_graph_entry = _runner.graph.entry_node if _runner else None
await bus.publish(
AgentEvent(
type=EventType.TRIGGER_ACTIVATED,
@@ -3710,6 +3810,8 @@ def register_queen_lifecycle_tools(
"trigger_id": trigger_id,
"trigger_type": t_type,
"trigger_config": t_config,
"name": tdef.description or trigger_id,
**({"entry_node": _graph_entry} if _graph_entry else {}),
},
)
)
@@ -3762,6 +3864,8 @@ def register_queen_lifecycle_tools(
# Emit event
bus = getattr(session, "event_bus", None)
if bus:
_runner = getattr(session, "runner", None)
_graph_entry = _runner.graph.entry_node if _runner else None
await bus.publish(
AgentEvent(
type=EventType.TRIGGER_ACTIVATED,
@@ -3770,6 +3874,8 @@ def register_queen_lifecycle_tools(
"trigger_id": trigger_id,
"trigger_type": t_type,
"trigger_config": t_config,
"name": tdef.description or trigger_id,
**({"entry_node": _graph_entry} if _graph_entry else {}),
},
)
)
@@ -3868,7 +3974,10 @@ def register_queen_lifecycle_tools(
AgentEvent(
type=EventType.TRIGGER_DEACTIVATED,
stream_id="queen",
data={"trigger_id": trigger_id},
data={
"trigger_id": trigger_id,
"name": tdef.description or trigger_id if tdef else trigger_id,
},
)
)
+2 -2
@@ -34,8 +34,8 @@ export const executionApi = {
graph_id: graphId,
}),
chat: (sessionId: string, message: string) =>
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message }),
chat: (sessionId: string, message: string, images?: { type: string; image_url: { url: string } }[]) =>
api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message, ...(images?.length ? { images } : {}) }),
/** Queue context for the queen without triggering an LLM response. */
queenContext: (sessionId: string, message: string) =>
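For reference, a minimal sketch of the request body the widened chat endpoint now accepts, written against the field names in the TypeScript client above; the base URL, port, and use of httpx are assumptions, not part of this changeset:

import base64

import httpx

def send_chat_with_image(session_id: str, message: str, png_bytes: bytes) -> dict:
    # Images travel as data URLs inside OpenAI-style image_url content parts.
    data_url = "data:image/png;base64," + base64.b64encode(png_bytes).decode()
    body = {
        "message": message,
        "images": [{"type": "image_url", "image_url": {"url": data_url}}],
    }
    resp = httpx.post(f"http://localhost:8000/sessions/{session_id}/chat", json=body, timeout=60)
    resp.raise_for_status()
    return resp.json()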
+11 -3
@@ -64,10 +64,14 @@ export const sessionsApi = {
`/sessions/${sessionId}/entry-points`,
),
updateTriggerTask: (sessionId: string, triggerId: string, task: string) =>
api.patch<{ trigger_id: string; task: string }>(
updateTrigger: (
sessionId: string,
triggerId: string,
patch: { task?: string; trigger_config?: Record<string, unknown> },
) =>
api.patch<{ trigger_id: string; task: string; trigger_config: Record<string, unknown> }>(
`/sessions/${sessionId}/triggers/${triggerId}`,
{ task },
patch,
),
graphs: (sessionId: string) =>
@@ -77,6 +81,10 @@ export const sessionsApi = {
eventsHistory: (sessionId: string) =>
api.get<{ events: AgentEvent[]; session_id: string }>(`/sessions/${sessionId}/events/history`),
/** Open the session's data folder in the OS file manager. */
revealFolder: (sessionId: string) =>
api.post<{ path: string }>(`/sessions/${sessionId}/reveal`),
/** List all queen sessions on disk — live + cold (post-restart). */
history: () =>
api.get<{ sessions: Array<{ session_id: string; cold: boolean; live: boolean; has_messages: boolean; created_at: number; agent_name?: string | null; agent_path?: string | null }> }>("/sessions/history"),
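A sketch of the widened PATCH contract: task and trigger_config may now be sent independently. The base URL is a placeholder, and a `cron` key inside trigger_config is an assumption inferred from the cron-draft state and cronToLabel helper elsewhere in this changeset:

import httpx

def update_trigger(session_id: str, trigger_id: str, *, task: str | None = None, cron: str | None = None) -> dict:
    patch: dict = {}
    if task is not None:
        patch["task"] = task
    if cron is not None:
        # Assumed config shape; only the fields being changed are sent.
        patch["trigger_config"] = {"cron": cron}
    resp = httpx.patch(
        f"http://localhost:8000/sessions/{session_id}/triggers/{trigger_id}",
        json=patch,
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()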
+5 -1
@@ -14,6 +14,8 @@ export interface LiveSession {
intro_message?: string;
/** Queen operating phase — "planning", "building", "staging", or "running" */
queen_phase?: "planning" | "building" | "staging" | "running";
/** Whether the queen's LLM supports image content in messages */
queen_supports_images?: boolean;
/** Present in 409 conflict responses when worker is still loading */
loading?: boolean;
}
@@ -324,6 +326,7 @@ export type EventTypeName =
| "node_retry"
| "edge_traversed"
| "context_compacted"
| "context_usage_updated"
| "webhook_received"
| "custom"
| "escalation_requested"
@@ -337,7 +340,8 @@ export type EventTypeName =
| "trigger_activated"
| "trigger_deactivated"
| "trigger_fired"
| "trigger_removed";
| "trigger_removed"
| "trigger_updated";
export interface AgentEvent {
type: EventTypeName;
+493 -109
@@ -1,8 +1,32 @@
import { memo, useState, useRef, useEffect } from "react";
import { Send, Square, Crown, Cpu, Check, Loader2 } from "lucide-react";
import { memo, useState, useRef, useEffect, useMemo } from "react";
import {
Send,
Square,
Crown,
Cpu,
Check,
Loader2,
Paperclip,
X,
} from "lucide-react";
export interface ImageContent {
type: "image_url";
image_url: { url: string };
}
export interface ContextUsageEntry {
usagePct: number;
messageCount: number;
estimatedTokens: number;
maxTokens: number;
}
import MarkdownContent from "@/components/MarkdownContent";
import QuestionWidget from "@/components/QuestionWidget";
import MultiQuestionWidget from "@/components/MultiQuestionWidget";
import ParallelSubagentBubble, {
type SubagentGroup,
} from "@/components/ParallelSubagentBubble";
export interface ChatMessage {
id: string;
@@ -10,7 +34,13 @@ export interface ChatMessage {
agentColor: string;
content: string;
timestamp: string;
type?: "system" | "agent" | "user" | "tool_status" | "worker_input_request" | "run_divider";
type?:
| "system"
| "agent"
| "user"
| "tool_status"
| "worker_input_request"
| "run_divider";
role?: "queen" | "worker";
/** Which worker thread this message belongs to (worker agent name) */
thread?: string;
@@ -18,11 +48,17 @@ export interface ChatMessage {
createdAt?: number;
/** Queen phase active when this message was created */
phase?: "planning" | "building" | "staging" | "running";
/** Images attached to a user message */
images?: ImageContent[];
/** Backend node_id that produced this message — used for subagent grouping */
nodeId?: string;
/** Backend execution_id for this message */
executionId?: string;
}
interface ChatPanelProps {
messages: ChatMessage[];
onSend: (message: string, thread: string) => void;
onSend: (message: string, thread: string, images?: ImageContent[]) => void;
isWaiting?: boolean;
/** When true a worker is thinking (not yet streaming) */
isWorkerWaiting?: boolean;
@@ -31,6 +67,8 @@ interface ChatPanelProps {
activeThread: string;
/** When true, the input is disabled (e.g. during loading) */
disabled?: boolean;
/** When false, the image attach button is hidden (model lacks vision support) */
supportsImages?: boolean;
/** Called when user clicks the stop button to cancel the queen's current turn */
onCancel?: () => void;
/** Pending question from ask_user — replaces textarea when present */
@@ -38,7 +76,9 @@ interface ChatPanelProps {
/** Options for the pending question */
pendingOptions?: string[] | null;
/** Multiple questions from ask_user_multiple */
pendingQuestions?: { id: string; prompt: string; options?: string[] }[] | null;
pendingQuestions?:
| { id: string; prompt: string; options?: string[] }[]
| null;
/** Called when user submits an answer to the pending question */
onQuestionSubmit?: (answer: string, isOther: boolean) => void;
/** Called when user submits answers to multiple questions */
@@ -47,6 +87,8 @@ interface ChatPanelProps {
onQuestionDismiss?: () => void;
/** Queen operating phase — shown as a tag on queen messages */
queenPhase?: "planning" | "building" | "staging" | "running";
/** Context window usage for queen and workers */
contextUsage?: Record<string, ContextUsageEntry>;
}
const queenColor = "hsl(45,95%,58%)";
@@ -72,7 +114,8 @@ const TOOL_HEX = [
function toolHex(name: string): string {
let hash = 0;
for (let i = 0; i < name.length; i++) hash = (hash * 31 + name.charCodeAt(i)) | 0;
for (let i = 0; i < name.length; i++)
hash = (hash * 31 + name.charCodeAt(i)) | 0;
return TOOL_HEX[Math.abs(hash) % TOOL_HEX.length];
}
@@ -120,12 +163,18 @@ function ToolActivityRow({ content }: { content: string }) {
<span
key={`run-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
style={{
color: hex,
backgroundColor: `${hex}18`,
border: `1px solid ${hex}35`,
}}
>
<Loader2 className="w-2.5 h-2.5 animate-spin" />
{p.name}
{p.count > 1 && (
<span className="text-[10px] font-medium opacity-70">×{p.count}</span>
<span className="text-[10px] font-medium opacity-70">
×{p.count}
</span>
)}
</span>
);
@@ -136,7 +185,11 @@ function ToolActivityRow({ content }: { content: string }) {
<span
key={`done-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
style={{
color: hex,
backgroundColor: `${hex}18`,
border: `1px solid ${hex}35`,
}}
>
<Check className="w-2.5 h-2.5" />
{p.name}
@@ -151,109 +204,249 @@ function ToolActivityRow({ content }: { content: string }) {
);
}
const MessageBubble = memo(function MessageBubble({ msg, queenPhase }: { msg: ChatMessage; queenPhase?: "planning" | "building" | "staging" | "running" }) {
const isUser = msg.type === "user";
const isQueen = msg.role === "queen";
const color = getColor(msg.agent, msg.role);
const MessageBubble = memo(
function MessageBubble({
msg,
queenPhase,
}: {
msg: ChatMessage;
queenPhase?: "planning" | "building" | "staging" | "running";
}) {
const isUser = msg.type === "user";
const isQueen = msg.role === "queen";
const color = getColor(msg.agent, msg.role);
if (msg.type === "run_divider") {
return (
<div className="flex items-center gap-3 py-2 my-1">
<div className="flex-1 h-px bg-border/60" />
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
{msg.content}
</span>
<div className="flex-1 h-px bg-border/60" />
</div>
);
}
if (msg.type === "system") {
return (
<div className="flex justify-center py-1">
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
{msg.content}
</span>
</div>
);
}
if (msg.type === "tool_status") {
return <ToolActivityRow content={msg.content} />;
}
if (isUser) {
return (
<div className="flex justify-end">
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
if (msg.type === "run_divider") {
return (
<div className="flex items-center gap-3 py-2 my-1">
<div className="flex-1 h-px bg-border/60" />
<span className="text-[10px] text-muted-foreground font-medium uppercase tracking-wider">
{msg.content}
</span>
<div className="flex-1 h-px bg-border/60" />
</div>
</div>
);
}
);
}
return (
<div className="flex gap-3">
<div
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
style={{
backgroundColor: `${color}18`,
border: `1.5px solid ${color}35`,
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
}}
>
{isQueen ? (
<Crown className="w-4 h-4" style={{ color }} />
) : (
<Cpu className="w-3.5 h-3.5" style={{ color }} />
)}
</div>
<div className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}>
<div className="flex items-center gap-2 mb-1">
<span className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`} style={{ color }}>
{msg.agent}
</span>
<span
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
isQueen ? "bg-primary/15 text-primary" : "bg-muted text-muted-foreground"
}`}
>
{isQueen
? ((msg.phase ?? queenPhase) === "running"
? "running"
: (msg.phase ?? queenPhase) === "staging"
? "staging"
: (msg.phase ?? queenPhase) === "planning"
? "planning"
: "building")
: "Worker"}
if (msg.type === "system") {
return (
<div className="flex justify-center py-1">
<span className="text-[11px] text-muted-foreground bg-muted/60 px-3 py-1.5 rounded-full">
{msg.content}
</span>
</div>
);
}
if (msg.type === "tool_status") {
return <ToolActivityRow content={msg.content} />;
}
if (isUser) {
return (
<div className="flex justify-end">
<div className="max-w-[75%] bg-primary text-primary-foreground text-sm leading-relaxed rounded-2xl rounded-br-md px-4 py-3">
{msg.images && msg.images.length > 0 && (
<div className="flex flex-wrap gap-2 mb-2">
{msg.images.map((img, i) => (
<img
key={i}
src={img.image_url.url}
alt={`attachment ${i + 1}`}
className="max-h-48 max-w-full rounded-lg object-contain"
/>
))}
</div>
)}
{msg.content && (
<p className="whitespace-pre-wrap break-words">{msg.content}</p>
)}
</div>
</div>
);
}
return (
<div className="flex gap-3">
<div
className={`flex-shrink-0 ${isQueen ? "w-9 h-9" : "w-7 h-7"} rounded-xl flex items-center justify-center`}
style={{
backgroundColor: `${color}18`,
border: `1.5px solid ${color}35`,
boxShadow: isQueen ? `0 0 12px ${color}20` : undefined,
}}
>
{isQueen ? (
<Crown className="w-4 h-4" style={{ color }} />
) : (
<Cpu className="w-3.5 h-3.5" style={{ color }} />
)}
</div>
<div
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
}`}
className={`flex-1 min-w-0 ${isQueen ? "max-w-[85%]" : "max-w-[75%]"}`}
>
<MarkdownContent content={msg.content} />
<div className="flex items-center gap-2 mb-1">
<span
className={`font-medium ${isQueen ? "text-sm" : "text-xs"}`}
style={{ color }}
>
{msg.agent}
</span>
<span
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
isQueen
? "bg-primary/15 text-primary"
: "bg-muted text-muted-foreground"
}`}
>
{isQueen
? (msg.phase ?? queenPhase) === "running"
? "running"
: (msg.phase ?? queenPhase) === "staging"
? "staging"
: (msg.phase ?? queenPhase) === "planning"
? "planning"
: "building"
: "Worker"}
</span>
</div>
<div
className={`text-sm leading-relaxed rounded-2xl rounded-tl-md px-4 py-3 ${
isQueen ? "border border-primary/20 bg-primary/5" : "bg-muted/60"
}`}
>
<MarkdownContent content={msg.content} />
</div>
</div>
</div>
</div>
);
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content && prev.msg.phase === next.msg.phase && prev.queenPhase === next.queenPhase);
);
},
(prev, next) =>
prev.msg.id === next.msg.id &&
prev.msg.content === next.msg.content &&
prev.msg.phase === next.msg.phase &&
prev.queenPhase === next.queenPhase,
);
export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting, isBusy, activeThread, disabled, onCancel, pendingQuestion, pendingOptions, pendingQuestions, onQuestionSubmit, onMultiQuestionSubmit, onQuestionDismiss, queenPhase }: ChatPanelProps) {
export default function ChatPanel({
messages,
onSend,
isWaiting,
isWorkerWaiting,
isBusy,
activeThread,
disabled,
onCancel,
pendingQuestion,
pendingOptions,
pendingQuestions,
onQuestionSubmit,
onMultiQuestionSubmit,
onQuestionDismiss,
queenPhase,
contextUsage,
supportsImages = true,
}: ChatPanelProps) {
const [input, setInput] = useState("");
const [pendingImages, setPendingImages] = useState<ImageContent[]>([]);
const [readMap, setReadMap] = useState<Record<string, number>>({});
const bottomRef = useRef<HTMLDivElement>(null);
const scrollRef = useRef<HTMLDivElement>(null);
const stickToBottom = useRef(true);
const textareaRef = useRef<HTMLTextAreaElement>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const threadMessages = messages.filter((m) => {
if (m.type === "system" && !m.thread) return false;
return m.thread === activeThread;
if (m.thread !== activeThread) return false;
// Hide queen messages whose content is whitespace-only — these are
// tool-use-only turns that have no visible text. During live operation
// tool pills provide context, but on resume the pills are gone so
// the empty bubble is meaningless.
if (m.role === "queen" && !m.type && (!m.content || !m.content.trim()))
return false;
return true;
});
// Group subagent messages into parallel bubbles.
// A subagent message has nodeId containing ":subagent:".
// The run only ends on hard boundaries (user messages, run_dividers)
// so interleaved queen/tool/system messages don't fragment the bubble.
type RenderItem =
| { kind: "message"; msg: ChatMessage }
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] };
const renderItems = useMemo<RenderItem[]>(() => {
const items: RenderItem[] = [];
let i = 0;
while (i < threadMessages.length) {
const msg = threadMessages[i];
const isSubagent = msg.nodeId?.includes(":subagent:");
if (!isSubagent) {
items.push({ kind: "message", msg });
i++;
continue;
}
// Start a subagent run. Collect all subagent messages, allowing
// non-subagent messages in between (they render as normal items
// before the bubble). Only break on hard boundaries.
const subagentMsgs: ChatMessage[] = [];
const interleaved: { idx: number; msg: ChatMessage }[] = [];
const firstId = msg.id;
while (i < threadMessages.length) {
const m = threadMessages[i];
const isSa = m.nodeId?.includes(":subagent:");
if (isSa) {
subagentMsgs.push(m);
i++;
continue;
}
// Hard boundary — stop the run
if (m.type === "user" || m.type === "run_divider") break;
// Worker message from a non-subagent node means the graph has
// moved on to the next stage. Close the bubble even if some
// subagents are still streaming in the background.
if (m.role === "worker" && m.nodeId && !m.nodeId.includes(":subagent:"))
break;
// Soft interruption (queen output, system, tool_status without
// nodeId) — render it normally but keep the subagent run going
interleaved.push({ idx: items.length + interleaved.length, msg: m });
i++;
}
// Emit interleaved messages first (before the bubble)
for (const { msg: im } of interleaved) {
items.push({ kind: "message", msg: im });
}
// Build the single parallel bubble from all collected subagent msgs
if (subagentMsgs.length > 0) {
const byNode = new Map<string, ChatMessage[]>();
for (const m of subagentMsgs) {
const nid = m.nodeId!;
if (!byNode.has(nid)) byNode.set(nid, []);
byNode.get(nid)!.push(m);
}
const groups: SubagentGroup[] = [];
for (const [nodeId, msgs] of byNode) {
groups.push({
nodeId,
messages: msgs,
contextUsage: contextUsage?.[nodeId],
});
}
items.push({ kind: "parallel", groupId: `par-${firstId}`, groups });
}
}
return items;
}, [threadMessages, contextUsage]);
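// Worked example with hypothetical ids: [queen, A:subagent:x, tool_status, B:subagent:y, user]
// renders as [queen, tool_status, parallel(A, B), user]: the soft interruption is
// emitted ahead of the bubble and the user message closes the run.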
// Mark current thread as read
useEffect(() => {
const count = messages.filter((m) => m.thread === activeThread).length;
@@ -284,26 +477,64 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!input.trim()) return;
onSend(input.trim(), activeThread);
if (!input.trim() && pendingImages.length === 0) return;
onSend(
input.trim(),
activeThread,
pendingImages.length > 0 ? pendingImages : undefined,
);
setInput("");
setPendingImages([]);
if (textareaRef.current) textareaRef.current.style.height = "auto";
};
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const files = Array.from(e.target.files ?? []);
if (files.length === 0) return;
files.forEach((file) => {
const reader = new FileReader();
reader.onload = (ev) => {
const url = ev.target?.result as string;
setPendingImages((prev) => [
...prev,
{ type: "image_url", image_url: { url } },
]);
};
reader.readAsDataURL(file);
});
// Reset so the same file can be re-selected
e.target.value = "";
};
return (
<div className="flex flex-col h-full min-w-0">
{/* Compact sub-header */}
<div className="px-5 pt-4 pb-2 flex items-center gap-2">
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">Conversation</p>
<p className="text-[11px] text-muted-foreground font-medium uppercase tracking-wider">
Conversation
</p>
</div>
{/* Messages */}
<div ref={scrollRef} onScroll={handleScroll} className="flex-1 overflow-auto px-5 py-4 space-y-3">
{threadMessages.map((msg) => (
<div key={msg.id}>
<MessageBubble msg={msg} queenPhase={queenPhase} />
</div>
))}
<div
ref={scrollRef}
onScroll={handleScroll}
className="flex-1 overflow-auto px-5 py-4 space-y-3"
>
{renderItems.map((item) =>
item.kind === "parallel" ? (
<div key={item.groupId}>
<ParallelSubagentBubble
groupId={item.groupId}
groups={item.groups}
/>
</div>
) : (
<div key={item.msg.id}>
<MessageBubble msg={item.msg} queenPhase={queenPhase} />
</div>
),
)}
{/* Show typing indicator while waiting for first queen response (disabled + empty chat) */}
{(isWaiting || (disabled && threadMessages.length === 0)) && (
@@ -320,9 +551,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
</div>
<div className="border border-primary/20 bg-primary/5 rounded-2xl rounded-tl-md px-4 py-3">
<div className="flex gap-1.5">
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "0ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "150ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "300ms" }}
/>
</div>
</div>
</div>
@@ -340,9 +580,18 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
</div>
<div className="bg-muted/60 rounded-2xl rounded-tl-md px-4 py-3">
<div className="flex gap-1.5">
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "0ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "150ms" }} />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce" style={{ animationDelay: "300ms" }} />
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "0ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "150ms" }}
/>
<span
className="w-1.5 h-1.5 rounded-full bg-muted-foreground animate-bounce"
style={{ animationDelay: "300ms" }}
/>
</div>
</div>
</div>
@@ -350,8 +599,99 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
<div ref={bottomRef} />
</div>
{/* Context window usage bar — sits between messages and input */}
{(() => {
if (!contextUsage) return null;
const queenUsage = contextUsage["__queen__"];
const workerEntries = Object.entries(contextUsage).filter(
([k]) => k !== "__queen__",
);
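// Surface only the most-pressured worker: the entry with the highest usagePct.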
const workerUsage =
workerEntries.length > 0
? workerEntries.reduce(
(best, [, v]) => (v.usagePct > best.usagePct ? v : best),
workerEntries[0][1],
)
: undefined;
if (!queenUsage && !workerUsage) return null;
return (
<div className="flex items-center gap-3 mx-4 px-3 py-1 rounded-lg bg-muted/30 border border-border/20 group/ctx flex-shrink-0">
{queenUsage && (
<div
className="flex items-center gap-2 flex-1 min-w-0"
title={`Queen: ${(queenUsage.estimatedTokens / 1000).toFixed(1)}k / ${(queenUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${queenUsage.messageCount} messages`}
>
<Crown
className="w-3 h-3 flex-shrink-0"
style={{ color: "hsl(45,95%,58%)" }}
/>
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(queenUsage.usagePct, 100)}%`,
backgroundColor:
queenUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: queenUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}
/>
</div>
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
<span className="group-hover/ctx:hidden">
{queenUsage.usagePct}%
</span>
<span className="hidden group-hover/ctx:inline">
{(queenUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
{(queenUsage.maxTokens / 1000).toFixed(0)}k
</span>
</span>
</div>
)}
{workerUsage && (
<div
className="flex items-center gap-2 flex-1 min-w-0"
title={`Worker: ${(workerUsage.estimatedTokens / 1000).toFixed(1)}k / ${(workerUsage.maxTokens / 1000).toFixed(0)}k tokens \u00b7 ${workerUsage.messageCount} messages`}
>
<Cpu
className="w-3 h-3 flex-shrink-0"
style={{ color: "hsl(220,60%,55%)" }}
/>
<div className="flex-1 h-1.5 rounded-full bg-muted/50 overflow-hidden min-w-[60px]">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(workerUsage.usagePct, 100)}%`,
backgroundColor:
workerUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: workerUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(220,60%,55%)",
}}
/>
</div>
<span className="text-[10px] text-muted-foreground/70 flex-shrink-0 tabular-nums">
<span className="group-hover/ctx:hidden">
{workerUsage.usagePct}%
</span>
<span className="hidden group-hover/ctx:inline">
{(workerUsage.estimatedTokens / 1000).toFixed(1)}k /{" "}
{(workerUsage.maxTokens / 1000).toFixed(0)}k
</span>
</span>
</div>
)}
</div>
);
})()}
{/* Input area — question widget replaces textarea when a question is pending */}
{pendingQuestions && pendingQuestions.length >= 2 && onMultiQuestionSubmit ? (
{pendingQuestions &&
pendingQuestions.length >= 2 &&
onMultiQuestionSubmit ? (
<MultiQuestionWidget
questions={pendingQuestions}
onSubmit={onMultiQuestionSubmit}
@@ -366,7 +706,47 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
/>
) : (
<form onSubmit={handleSubmit} className="p-4">
{/* Image preview strip */}
{pendingImages.length > 0 && (
<div className="flex flex-wrap gap-2 mb-2 px-1">
{pendingImages.map((img, i) => (
<div key={i} className="relative group">
<img
src={img.image_url.url}
alt={`preview ${i + 1}`}
className="h-16 w-16 object-cover rounded-lg border border-border"
/>
<button
type="button"
onClick={() =>
setPendingImages((prev) => prev.filter((_, j) => j !== i))
}
className="absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full bg-destructive text-destructive-foreground flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity"
>
<X className="w-2.5 h-2.5" />
</button>
</div>
))}
</div>
)}
<div className="flex items-center gap-3 bg-muted/40 rounded-xl px-4 py-2.5 border border-border focus-within:border-primary/40 transition-colors">
<input
ref={fileInputRef}
type="file"
accept="image/*"
multiple
className="hidden"
onChange={handleFileChange}
/>
<button
type="button"
disabled={disabled || !supportsImages}
onClick={() => supportsImages && fileInputRef.current?.click()}
className="flex-shrink-0 p-1 rounded-md text-muted-foreground hover:text-foreground disabled:opacity-30 transition-colors"
title={supportsImages ? "Attach image" : "Image not supported by the current model"}
>
<Paperclip className="w-4 h-4" />
</button>
<textarea
ref={textareaRef}
rows={1}
@@ -383,7 +763,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
handleSubmit(e);
}
}}
placeholder={disabled ? "Connecting to agent..." : "Message Queen Bee..."}
placeholder={
disabled ? "Connecting to agent..." : "Message Queen Bee..."
}
disabled={disabled}
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed resize-none overflow-y-auto"
/>
@@ -398,7 +780,9 @@ export default function ChatPanel({ messages, onSend, isWaiting, isWorkerWaiting
) : (
<button
type="submit"
disabled={!input.trim() || disabled}
disabled={
(!input.trim() && pendingImages.length === 0) || disabled
}
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
>
<Send className="w-4 h-4" />
+161 -14
@@ -3,11 +3,23 @@ import { Loader2 } from "lucide-react";
import type { DraftGraph as DraftGraphData, DraftNode } from "@/api/types";
import { RunButton } from "./RunButton";
import type { GraphNode, RunState } from "./graph-types";
import {
cssVar,
truncateLabel,
TRIGGER_ICONS,
ACTIVE_TRIGGER_COLORS,
useTriggerColors,
} from "@/lib/graphUtils";
// Read a CSS custom property value (space-separated HSL components)
function cssVar(name: string): string {
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
}
// ── Trigger layout constants ──
const TRIGGER_H = 38; // pill height
const TRIGGER_PILL_GAP_X = 16; // horizontal gap between multiple trigger pills
const TRIGGER_ICON_X = 16; // icon center offset from pill left edge
const TRIGGER_LABEL_X = 30; // label start offset from pill left edge
const TRIGGER_LABEL_INSET = 38; // icon + padding subtracted from pill width for label space
const TRIGGER_TEXT_Y = 11; // y-offset below pill for first text line (countdown or status)
const TRIGGER_TEXT_STEP = 11; // additional y-offset for second text line when countdown present
const TRIGGER_CLEARANCE = 30; // vertical space below pill for countdown + status text
interface DraftChromeColors {
edge: string;
@@ -107,13 +119,6 @@ function formatNodeId(id: string): string {
return id.split("-").map(w => w.charAt(0).toUpperCase() + w.slice(1)).join(" ");
}
function truncateLabel(label: string, availablePx: number, fontSize: number): string {
const avgCharW = fontSize * 0.58;
const maxChars = Math.floor(availablePx / avgCharW);
if (label.length <= maxChars) return label;
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
}
/** Return the bounding-rect corner radius for a given flowchart shape. */
/**
* Render an ISO 5807 flowchart shape as an SVG element.
@@ -240,6 +245,13 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
const runBtnRef = useRef<HTMLButtonElement>(null);
const [containerW, setContainerW] = useState(484);
const chrome = useDraftChromeColors();
const triggerColors = useTriggerColors();
// Extract trigger nodes from runtimeNodes
const triggerNodes = useMemo(
() => (runtimeNodes ?? []).filter(n => n.nodeType === "trigger"),
[runtimeNodes],
);
// ── Entrance animation — fires when originalDraft becomes a new non-null value ──
// This covers: agent loaded, build finished, queen modifies flowchart.
@@ -709,12 +721,17 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
return { nodeYOffset: offsets, totalExtraY: totalExtra, groupBoxMaxX: maxGroupX };
}, [nodes, maxLayer, flowchartMap, idxMap, layers, nodeXPositions, nodeW]);
// When triggers are present, push the entire draft graph down to make room
const triggerOffsetY = triggerNodes.length > 0
? TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + TRIGGER_CLEARANCE
: 0;
const nodePos = (i: number) => ({
x: nodeXPositions[i],
y: TOP_Y + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
y: TOP_Y + triggerOffsetY + layers[i] * (NODE_H + GAP_Y) + nodeYOffset[i],
});
const svgHeight = TOP_Y + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
const svgHeight = TOP_Y + triggerOffsetY + (maxLayer + 1) * NODE_H + maxLayer * GAP_Y + totalExtraY + 16;
// Compute group areas for runtime node boundaries on the draft
const groupAreas = useMemo(() => {
@@ -847,6 +864,131 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
pending: "",
};
// ── Trigger node rendering ──
const triggerW = Math.min(nodeW, 180);
// Shared trigger pill X position (used by both node and edge renderers)
const triggerPillX = (idx: number) => {
const totalW = triggerNodes.length * triggerW + (triggerNodes.length - 1) * TRIGGER_PILL_GAP_X;
return (containerW - totalW) / 2 + idx * (triggerW + TRIGGER_PILL_GAP_X);
};
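// e.g. two 180px pills in the default 484px container: totalW = 2*180 + 16 = 376,
// so pill 0 starts at x = (484 - 376) / 2 = 54 and pill 1 at x = 54 + 196 = 250.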
const renderTriggerNode = (node: GraphNode, triggerIdx: number) => {
const icon = TRIGGER_ICONS[node.triggerType || ""] || "\u26A1";
const isActive = node.status === "running" || node.status === "complete";
const colors = isActive ? ACTIVE_TRIGGER_COLORS : triggerColors;
const nextFireIn = node.triggerConfig?.next_fire_in as number | undefined;
const tx = triggerPillX(triggerIdx);
const ty = TOP_Y;
const fontSize = triggerW < 140 ? 10.5 : 11.5;
const displayLabel = truncateLabel(node.label, triggerW - TRIGGER_LABEL_INSET, fontSize);
// Countdown
let countdownLabel: string | null = null;
if (isActive && nextFireIn != null && nextFireIn > 0) {
const h = Math.floor(nextFireIn / 3600);
const m = Math.floor((nextFireIn % 3600) / 60);
const s = Math.floor(nextFireIn % 60);
countdownLabel = h > 0
? `next in ${h}h ${String(m).padStart(2, "0")}m`
: `next in ${m}m ${String(s).padStart(2, "0")}s`;
}
const statusLabel = isActive ? "active" : "inactive";
const statusColor = isActive ? "hsl(140,40%,50%)" : "hsl(210,20%,40%)";
return (
<g
key={node.id}
onClick={() => onRuntimeNodeClick?.(node.id)}
style={{ cursor: onRuntimeNodeClick ? "pointer" : "default" }}
>
<title>{node.label}</title>
{/* Pill-shaped background */}
<rect
x={tx} y={ty}
width={triggerW} height={TRIGGER_H}
rx={TRIGGER_H / 2}
fill={colors.bg}
stroke={colors.border}
strokeWidth={isActive ? 1.5 : 1}
strokeDasharray={isActive ? undefined : "4 2"}
/>
{/* Icon */}
<text
x={tx + TRIGGER_ICON_X} y={ty + TRIGGER_H / 2}
fill={colors.icon} fontSize={13}
textAnchor="middle" dominantBaseline="middle"
>
{icon}
</text>
{/* Label */}
<text
x={tx + TRIGGER_LABEL_X} y={ty + TRIGGER_H / 2}
fill={colors.text}
fontSize={fontSize}
fontWeight={500}
dominantBaseline="middle"
letterSpacing="0.01em"
>
{displayLabel}
</text>
{/* Countdown */}
{countdownLabel && (
<text
x={tx + triggerW / 2} y={ty + TRIGGER_H + TRIGGER_TEXT_Y}
fill={colors.text} fontSize={9}
textAnchor="middle" fontStyle="italic" opacity={0.7}
>
{countdownLabel}
</text>
)}
{/* Status */}
<text
x={tx + triggerW / 2} y={ty + TRIGGER_H + (countdownLabel ? TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP : TRIGGER_TEXT_Y)}
fill={statusColor} fontSize={8.5}
textAnchor="middle" opacity={0.8}
>
{statusLabel}
</text>
</g>
);
};
const renderTriggerEdge = (triggerIdx: number) => {
if (nodes.length === 0) return null;
const triggerNode = triggerNodes[triggerIdx];
const runtimeTargetId = triggerNode?.next?.[0];
const targetDraftId = runtimeTargetId
? flowchartMap?.[runtimeTargetId]?.[0] ?? runtimeTargetId
: draft?.entry_node;
const targetIdx = targetDraftId ? idxMap[targetDraftId] ?? 0 : 0;
const targetPos = nodePos(targetIdx);
const targetX = targetPos.x + nodeW / 2;
const targetY = targetPos.y;
const tx = triggerPillX(triggerIdx) + triggerW / 2;
const ty = TOP_Y + TRIGGER_H + TRIGGER_TEXT_Y + TRIGGER_TEXT_STEP + 4;
const midY = (ty + targetY) / 2;
const d = Math.abs(tx - targetX) < 2
? `M ${tx} ${ty} L ${targetX} ${targetY}`
: `M ${tx} ${ty} L ${tx} ${midY} L ${targetX} ${midY} L ${targetX} ${targetY}`;
return (
<g key={`trigger-edge-${triggerIdx}`}>
<path d={d} fill="none" stroke={chrome.edge} strokeWidth={1.2} strokeDasharray="4 3" />
<polygon
points={`${targetX - 3},${targetY - 5} ${targetX + 3},${targetY - 5} ${targetX},${targetY - 1}`}
fill={chrome.edgeArrow}
/>
</g>
);
};
const renderNode = (node: DraftNode, i: number) => {
const pos = nodePos(i);
const isHovered = hoveredNode === node.id;
@@ -994,7 +1136,7 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
>
<svg
width="100%"
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX) + (backEdgeOverflow ?? 0)} ${totalH}`}
viewBox={`0 0 ${Math.max((maxContentRight ?? 0), groupBoxMaxX, triggerNodes.length > 0 ? triggerPillX(triggerNodes.length - 1) + triggerW : 0) + (backEdgeOverflow ?? 0)} ${totalH}`}
preserveAspectRatio="xMidYMin meet"
className="select-none"
style={{
@@ -1078,6 +1220,11 @@ export default function DraftGraph({ draft, originalDraft, onNodeClick, flowchar
);
})}
{/* Trigger edges (dashed lines from trigger pills to first draft node) */}
{triggerNodes.map((_, i) => renderTriggerEdge(i))}
{/* Trigger pill nodes */}
{triggerNodes.map((tn, i) => renderTriggerNode(tn, i))}
{forwardEdges.map((e, i) => renderEdge(e, i))}
{backEdges.map((e, i) => renderBackEdge(e, i))}
{nodes.map((n, i) => renderNode(n, i))}
@@ -28,6 +28,13 @@ export interface SubagentReport {
status?: "running" | "complete" | "error";
}
interface ContextUsage {
usagePct: number;
messageCount: number;
estimatedTokens: number;
maxTokens: number;
}
interface NodeDetailPanelProps {
node: GraphNode | null;
nodeSpec?: NodeSpec | null;
@@ -38,6 +45,7 @@ interface NodeDetailPanelProps {
workerSessionId?: string | null;
nodeLogs?: string[];
actionPlan?: string;
contextUsage?: ContextUsage;
onClose: () => void;
}
@@ -309,7 +317,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
];
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, contextUsage, onClose }: NodeDetailPanelProps) {
const [activeTab, setActiveTab] = useState<Tab>("overview");
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
@@ -389,6 +397,43 @@ export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagent
</div>
)}
{/* Context window usage */}
{contextUsage && (
<div className="px-4 py-2 border-b border-border/20 flex-shrink-0">
<div className="flex items-center gap-2 mb-1">
<span className="text-[10px] text-muted-foreground font-medium">Context</span>
<span className="text-[10px] text-muted-foreground/70 ml-auto">
{(contextUsage.estimatedTokens / 1000).toFixed(1)}k / {(contextUsage.maxTokens / 1000).toFixed(0)}k tokens
</span>
</div>
<div className="w-full h-1.5 rounded-full bg-muted/50 overflow-hidden">
<div
className="h-full rounded-full transition-all duration-500 ease-out"
style={{
width: `${Math.min(contextUsage.usagePct, 100)}%`,
backgroundColor: contextUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: contextUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}
/>
</div>
<div className="flex items-center gap-2 mt-1">
<span className="text-[10px] text-muted-foreground/60">{contextUsage.messageCount} messages</span>
<span className="text-[10px] font-medium ml-auto" style={{
color: contextUsage.usagePct >= 90
? "hsl(0,65%,55%)"
: contextUsage.usagePct >= 70
? "hsl(35,90%,55%)"
: "hsl(45,95%,58%)",
}}>
{contextUsage.usagePct}%
</span>
</div>
</div>
)}
{/* Tab bar */}
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
@@ -0,0 +1,413 @@
import { memo, useState, useRef, useEffect } from "react";
import { ChevronDown, ChevronUp, Cpu } from "lucide-react";
import type { ChatMessage, ContextUsageEntry } from "@/components/ChatPanel";
import MarkdownContent from "@/components/MarkdownContent";
// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------
const workerColor = "hsl(220,60%,55%)";
const SUBAGENT_COLORS = [
"hsl(220,60%,55%)",
"hsl(260,50%,55%)",
"hsl(180,50%,45%)",
"hsl(30,70%,50%)",
"hsl(340,55%,50%)",
"hsl(150,45%,45%)",
"hsl(45,80%,50%)",
"hsl(290,45%,55%)",
];
function colorForIndex(i: number): string {
return SUBAGENT_COLORS[i % SUBAGENT_COLORS.length];
}
function subagentLabel(nodeId: string): string {
const parts = nodeId.split(":subagent:");
const raw = parts.length >= 2 ? parts[1] : nodeId;
return raw
.replace(/:\d+$/, "") // strip instance suffix like ":3"
.replace(/[_-]/g, " ")
.replace(/\b\w/g, (c) => c.toUpperCase())
.trim();
}
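// Illustrative, assuming the documented id shape:
// subagentLabel("plan:subagent:fetch_data:2") -> "Fetch Data"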
function last<T>(arr: T[]): T | undefined {
return arr[arr.length - 1];
}
export interface SubagentGroup {
nodeId: string;
messages: ChatMessage[];
contextUsage?: ContextUsageEntry;
}
interface ParallelSubagentBubbleProps {
groups: SubagentGroup[];
groupId: string;
}
// ---------------------------------------------------------------------------
// Tool overlay — shown when a tool_status message is active (not all done)
// ---------------------------------------------------------------------------
function ToolOverlay({
toolName,
color,
visible,
}: {
toolName: string;
color: string;
visible: boolean;
}) {
return (
<div
className="absolute inset-0 top-[22px] flex items-center justify-center transition-opacity duration-200 z-10"
style={{
background: "rgba(8,8,14,0.82)",
opacity: visible ? 1 : 0,
pointerEvents: visible ? "auto" : "none",
}}
>
<div className="text-center px-3 py-2 rounded-md border" style={{ borderColor: `${color}40` }}>
<div className="text-[10px] font-medium" style={{ color }}>
{toolName}
</div>
<div className="text-[11px] mt-0.5" style={{ color }}>
{visible ? "..." : "\u2713"}
</div>
</div>
</div>
);
}
// ---------------------------------------------------------------------------
// Single tmux pane
// ---------------------------------------------------------------------------
function MuxPane({
group,
index,
label,
isFocused,
isZoomed,
onClickTitle,
}: {
group: SubagentGroup;
index: number;
label: string;
isFocused: boolean;
isZoomed: boolean;
onClickTitle: () => void;
}) {
const bodyRef = useRef<HTMLDivElement>(null);
const stickRef = useRef(true);
const color = colorForIndex(index);
const pct = group.contextUsage?.usagePct ?? 0;
const streamMsgs = group.messages.filter((m) => m.type !== "tool_status");
const latestContent = last(streamMsgs)?.content ?? "";
const msgCount = streamMsgs.length;
// Detect active tool and finished state from latest tool_status
const latestTool = last(
group.messages.filter((m) => m.type === "tool_status")
);
let activeToolName = "";
let toolRunning = false;
let isFinished = false;
if (latestTool) {
try {
const parsed = JSON.parse(latestTool.content);
const tools: { name: string; done: boolean }[] = parsed.tools || [];
const allDone = parsed.allDone as boolean | undefined;
const running = tools.find((t) => !t.done);
if (running) {
activeToolName = running.name;
toolRunning = true;
}
// Finished when all tools are done and one of them is set_output
// or report_to_parent (terminal tool calls)
if (allDone && tools.length > 0) {
const hasTerminal = tools.some(
(t) =>
t.done &&
(t.name === "set_output" || t.name === "report_to_parent")
);
if (hasTerminal) isFinished = true;
}
} catch {
/* ignore */
}
}
// Auto-scroll
useEffect(() => {
if (stickRef.current && bodyRef.current) {
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
}
}, [latestContent]);
const handleScroll = () => {
const el = bodyRef.current;
if (!el) return;
stickRef.current = el.scrollHeight - el.scrollTop - el.clientHeight < 30;
};
return (
<div
className="flex flex-col min-h-0 overflow-hidden relative transition-all duration-200"
style={{
borderWidth: 1,
borderStyle: "solid",
borderColor: isFocused && !isFinished ? `${color}60` : "transparent",
opacity: isFinished ? 0.4 : isFocused || isZoomed ? 1 : 0.55,
...(isZoomed
? { gridColumn: "1 / -1", gridRow: "1 / -1", zIndex: 10 }
: {}),
}}
>
{/* Title bar */}
<div
className="flex items-center gap-1.5 px-2 py-[3px] flex-shrink-0 cursor-pointer select-none"
style={{ background: "#0e0e16", borderBottom: "1px solid #1a1a2a" }}
onClick={onClickTitle}
>
{isFinished ? (
<span className="text-[8px] flex-shrink-0 leading-none" style={{ color: "#4a4" }}>&#10003;</span>
) : (
<div
className="w-[6px] h-[6px] rounded-full flex-shrink-0"
style={{ background: color }}
/>
)}
<span className="text-[9px] flex-shrink-0" style={{ color: isFinished ? "#555" : color }}>
{label}
</span>
<span className="flex-1" />
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
{msgCount}
</span>
<div
className="w-[36px] h-[3px] rounded-full overflow-hidden flex-shrink-0"
style={{ background: "#1a1a2a" }}
>
<div
className="h-full rounded-full transition-all duration-500"
style={{
width: `${Math.min(pct, 100)}%`,
backgroundColor:
pct >= 80 ? "hsl(0,65%,55%)" : pct >= 50 ? "hsl(35,90%,55%)" : color,
}}
/>
</div>
<span className="text-[8px] tabular-nums flex-shrink-0" style={{ color: "#555" }}>
{pct}%
</span>
</div>
{/* Body */}
<div
ref={bodyRef}
onScroll={handleScroll}
className="flex-1 min-h-0 overflow-y-auto px-2 py-1 text-[10px] leading-[1.7]"
style={{ background: "#08080e", color: "#555", fontFamily: "monospace" }}
>
{latestContent ? (
<div style={{ color: "#ccc" }}>
<MarkdownContent content={latestContent} />
</div>
) : (
<span style={{ color: "#333" }}>waiting...</span>
)}
{/* Blinking cursor — hidden when finished */}
{!isFinished && (
<span
className="inline-block w-[6px] h-[11px] align-middle ml-0.5"
style={{
background: color,
animation: "cursorBlink 1s step-end infinite",
}}
/>
)}
</div>
{/* Tool overlay */}
<ToolOverlay
toolName={activeToolName}
color={color}
visible={toolRunning}
/>
</div>
);
}
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
const ParallelSubagentBubble = memo(
function ParallelSubagentBubble({ groups }: ParallelSubagentBubbleProps) {
const [expanded, setExpanded] = useState(false);
const [zoomedIdx, setZoomedIdx] = useState<number | null>(null);
// Labels with instance numbers for duplicates
const labels: string[] = (() => {
const countByBase = new Map<string, number>();
const bases = groups.map((g) => subagentLabel(g.nodeId));
for (const b of bases)
countByBase.set(b, (countByBase.get(b) ?? 0) + 1);
const idxByBase = new Map<string, number>();
return bases.map((b) => {
if ((countByBase.get(b) ?? 1) <= 1) return b;
const idx = (idxByBase.get(b) ?? 0) + 1;
idxByBase.set(b, idx);
return `${b} #${idx}`;
});
})();
// Latest-active pane
const latestIdx = groups.reduce<number>((best, g, i) => {
const filtered = g.messages.filter((m) => m.type !== "tool_status");
const lm = last(filtered);
if (!lm) return best;
if (best < 0) return i;
const bm = last(
groups[best].messages.filter((m) => m.type !== "tool_status")
);
if (!bm) return i;
return (lm.createdAt ?? 0) >= (bm.createdAt ?? 0) ? i : best;
}, -1);
// Per-group finished detection (same logic as MuxPane)
const finishedFlags = groups.map((g) => {
const lt = last(g.messages.filter((m) => m.type === "tool_status"));
if (!lt) return false;
try {
const p = JSON.parse(lt.content);
const tools: { name: string; done: boolean }[] = p.tools || [];
if (!p.allDone || tools.length === 0) return false;
return tools.some(
(t) => t.done && (t.name === "set_output" || t.name === "report_to_parent")
);
} catch { return false; }
});
const activeCount = finishedFlags.filter((f) => !f).length;
if (groups.length === 0) return null;
// Grid sizing: 2 columns, auto rows capped at a fixed height
const rows = Math.ceil(groups.length / 2);
const gridHeight = expanded
? Math.min(rows * 200, 480)
: Math.min(rows * 100, 240);
return (
<div className="flex gap-3">
{/* Left icon */}
<div
className="flex-shrink-0 w-7 h-7 rounded-xl flex items-center justify-center mt-1"
style={{
backgroundColor: `${workerColor}18`,
border: `1.5px solid ${workerColor}35`,
}}
>
<Cpu className="w-3.5 h-3.5" style={{ color: workerColor }} />
</div>
<div className="flex-1 min-w-0 max-w-[90%]">
{/* Header */}
<div className="flex items-center gap-2 mb-1">
<span className="font-medium text-xs" style={{ color: workerColor }}>
{groups.length === 1 ? "Sub-agent" : "Parallel Agents"}
</span>
<span className="text-[10px] font-medium px-1.5 py-0.5 rounded-md bg-muted text-muted-foreground">
{activeCount > 0 ? `${activeCount} running` : `${groups.length} done`}
</span>
<button
onClick={() => {
setExpanded((v) => !v);
setZoomedIdx(null);
}}
className="ml-auto text-muted-foreground/60 hover:text-muted-foreground transition-colors p-0.5 rounded"
title={expanded ? "Collapse" : "Expand"}
>
{expanded ? (
<ChevronUp className="w-3.5 h-3.5" />
) : (
<ChevronDown className="w-3.5 h-3.5" />
)}
</button>
</div>
{/* Mux frame */}
<div
className="rounded-lg overflow-hidden"
style={{
border: "2px solid #1a1a2a",
background: "#08080e",
}}
>
{/* Grid */}
<div
className="grid gap-px"
style={{
gridTemplateColumns:
groups.length === 1 ? "1fr" : "1fr 1fr",
gridTemplateRows: `repeat(${rows}, 1fr)`,
height: gridHeight,
background: "#111",
}}
>
{groups.map((group, i) => (
<MuxPane
key={group.nodeId}
group={group}
index={i}
label={labels[i]}
isFocused={latestIdx === i}
isZoomed={zoomedIdx === i}
onClickTitle={() =>
setZoomedIdx(zoomedIdx === i ? null : i)
}
/>
))}
</div>
</div>
</div>
</div>
);
},
(prev, next) =>
prev.groupId === next.groupId &&
prev.groups.length === next.groups.length &&
prev.groups.every(
(g, i) =>
g.nodeId === next.groups[i].nodeId &&
g.messages.length === next.groups[i].messages.length &&
last(g.messages)?.content === last(next.groups[i].messages)?.content &&
g.contextUsage?.usagePct === next.groups[i].contextUsage?.usagePct
)
);
export default ParallelSubagentBubble;
// Injected as a global style (keyframes can't be inline)
if (typeof document !== "undefined") {
const id = "parallel-subagent-keyframes";
if (!document.getElementById(id)) {
const style = document.createElement("style");
style.id = id;
style.textContent = `
@keyframes cursorBlink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
@keyframes thermoPulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.4; } }
`;
document.head.appendChild(style);
}
}
+6 -2
@@ -62,7 +62,7 @@ export function sseEventToChatMessage(
const innerSuffix = innerTurn != null && innerTurn > 0 ? `-t${innerTurn}` : "";
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
if (!snapshot) return null;
if (!snapshot.trim()) return null;
return {
id: `stream-${iterIdKey}${innerSuffix}-${event.node_id}`,
agent: agentDisplayName || event.node_id || "Agent",
@@ -72,6 +72,8 @@ export function sseEventToChatMessage(
role: "worker",
thread,
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
};
}
@@ -100,7 +102,7 @@ export function sseEventToChatMessage(
const llmInnerSuffix = llmInnerTurn != null && llmInnerTurn > 0 ? `-t${llmInnerTurn}` : "";
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
if (!snapshot) return null;
if (!snapshot.trim()) return null;
return {
id: `stream-${idKey}${llmInnerSuffix}-${event.node_id}`,
agent: event.node_id || "Agent",
@@ -110,6 +112,8 @@ export function sseEventToChatMessage(
role: "worker",
thread,
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
};
}
+88
@@ -0,0 +1,88 @@
import { useEffect, useState } from "react";
// ── Shared graph utilities ──
// Common helpers used by both AgentGraph and DraftGraph.
// AgentGraph still has its own copies for now (separate cleanup PR).
/** Read a CSS custom property value (space-separated HSL components). */
export function cssVar(name: string): string {
return getComputedStyle(document.documentElement).getPropertyValue(name).trim();
}
/** Truncate label to fit within `availablePx` at the given fontSize. */
export function truncateLabel(label: string, availablePx: number, fontSize: number): string {
const avgCharW = fontSize * 0.58;
const maxChars = Math.floor(availablePx / avgCharW);
if (label.length <= maxChars) return label;
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
}
// ── Trigger styling ──
export type TriggerColorSet = { bg: string; border: string; text: string; icon: string };
export function buildTriggerColors(): TriggerColorSet {
const bg = cssVar("--trigger-bg") || "210 25% 14%";
const border = cssVar("--trigger-border") || "210 30% 30%";
const text = cssVar("--trigger-text") || "210 30% 65%";
const icon = cssVar("--trigger-icon") || "210 40% 55%";
return {
bg: `hsl(${bg})`,
border: `hsl(${border})`,
text: `hsl(${text})`,
icon: `hsl(${icon})`,
};
}
export const ACTIVE_TRIGGER_COLORS: TriggerColorSet = {
bg: "hsl(210,30%,18%)",
border: "hsl(210,50%,50%)",
text: "hsl(210,40%,75%)",
icon: "hsl(210,60%,65%)",
};
export const TRIGGER_ICONS: Record<string, string> = {
webhook: "\u26A1", // lightning bolt
timer: "\u23F1", // stopwatch
api: "\u2192", // right arrow
event: "\u223F", // sine wave
};
/** Format a cron expression into a human-readable schedule label. */
export function cronToLabel(cron: string): string {
const parts = cron.trim().split(/\s+/);
if (parts.length !== 5) return cron;
const [min, hour, dom, mon, dow] = parts;
// */N * * * * -> "Every Nm"
if (min.startsWith("*/") && hour === "*" && dom === "*" && mon === "*" && dow === "*") {
return `Every ${min.slice(2)}m`;
}
// 0 */N * * * -> "Every Nh"
if (min === "0" && hour.startsWith("*/") && dom === "*" && mon === "*" && dow === "*") {
return `Every ${hour.slice(2)}h`;
}
// M H * * * -> "Daily at H[:MM]AM/PM" (minutes omitted when zero)
if (dom === "*" && mon === "*" && dow === "*" && !min.includes("*") && !hour.includes("*")) {
const h = parseInt(hour, 10);
const m = parseInt(min, 10);
const suffix = h >= 12 ? "PM" : "AM";
const h12 = h % 12 || 12;
return m === 0 ? `Daily at ${h12}${suffix}` : `Daily at ${h12}:${String(m).padStart(2, "0")}${suffix}`;
}
return cron;
}
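// Illustrative: cronToLabel("*/15 * * * *") -> "Every 15m",
// cronToLabel("30 9 * * *") -> "Daily at 9:30AM"; unrecognized shapes pass through unchanged.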
/** Theme-reactive hook for inactive trigger colors. */
export function useTriggerColors(): TriggerColorSet {
const [colors, setColors] = useState<TriggerColorSet>(buildTriggerColors);
useEffect(() => {
const rebuild = () => setColors(buildTriggerColors());
const obs = new MutationObserver(rebuild);
obs.observe(document.documentElement, { attributes: true, attributeFilter: ["class", "style"] });
return () => obs.disconnect();
}, []);
return colors;
}
+8 -1
@@ -27,7 +27,14 @@ export default function MyAgents() {
agentsApi
.discover()
.then((result) => {
setAgents(result["Your Agents"] || []);
const entries = result["Your Agents"] || [];
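// Most recently active first; agents without last_active sink to the bottom.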
entries.sort((a, b) => {
if (!a.last_active && !b.last_active) return 0;
if (!a.last_active) return 1;
if (!b.last_active) return -1;
return b.last_active.localeCompare(a.last_active);
});
setAgents(entries);
})
.catch((err) => {
setError(err.message || "Failed to load agents");
+314 -33
@@ -1,7 +1,7 @@
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import ReactDOM from "react-dom";
import { useSearchParams, useNavigate } from "react-router-dom";
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X } from "lucide-react";
import { Plus, KeyRound, Sparkles, Layers, ChevronLeft, Bot, Loader2, WifiOff, X, FolderOpen } from "lucide-react";
import type { GraphNode, NodeStatus } from "@/components/graph-types";
import DraftGraph from "@/components/DraftGraph";
import ChatPanel, { type ChatMessage } from "@/components/ChatPanel";
@@ -17,6 +17,7 @@ import { useMultiSSE } from "@/hooks/use-sse";
import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as DraftGraphData } from "@/api/types";
import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers";
import { topologyToGraphNodes } from "@/lib/graph-converter";
import { cronToLabel } from "@/lib/graphUtils";
import { ApiError } from "@/api/client";
const makeId = () => Math.random().toString(36).slice(2, 9);
@@ -251,6 +252,10 @@ function truncate(s: string, max: number): string {
type SessionRestoreResult = {
messages: ChatMessage[];
restoredPhase: "planning" | "building" | "staging" | "running" | null;
/** Last flowchart map from events — used to restore flowchart overlay on cold resume. */
flowchartMap: Record<string, string[]> | null;
/** Last original draft from events — used to restore flowchart overlay on cold resume. */
originalDraft: DraftGraphData | null;
};
/**
@@ -267,6 +272,8 @@ async function restoreSessionMessages(
if (events.length > 0) {
const messages: ChatMessage[] = [];
let runningPhase: ChatMessage["phase"] = undefined;
let flowchartMap: Record<string, string[]> | null = null;
let originalDraft: DraftGraphData | null = null;
for (const evt of events) {
// Track phase transitions so each message gets the phase it was created in
const p = evt.type === "queen_phase_changed" ? evt.data?.phase as string
@@ -275,6 +282,12 @@ async function restoreSessionMessages(
if (p && ["planning", "building", "staging", "running"].includes(p)) {
runningPhase = p as ChatMessage["phase"];
}
// Track last flowchart state for cold restore
if (evt.type === "flowchart_map_updated" && evt.data) {
const mapData = evt.data as { map?: Record<string, string[]>; original_draft?: DraftGraphData };
flowchartMap = mapData.map ?? null;
originalDraft = mapData.original_draft ?? null;
}
const msg = sseEventToChatMessage(evt, thread, agentDisplayName);
if (!msg) continue;
if (evt.stream_id === "queen") {
@@ -283,12 +296,12 @@ async function restoreSessionMessages(
}
messages.push(msg);
}
return { messages, restoredPhase: runningPhase ?? null };
return { messages, restoredPhase: runningPhase ?? null, flowchartMap, originalDraft };
}
} catch {
// Event log not available — session will start fresh.
}
return { messages: [], restoredPhase: null };
return { messages: [], restoredPhase: null, flowchartMap: null, originalDraft: null };
}
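For reference, an assumed shape for the flowchart_map_updated event consumed above, inferred from the cast inside the event loop (illustrative only, not taken from the source):
// Illustrative flowchart_map_updated event (assumed shape):
const exampleFlowchartEvent = {
  type: "flowchart_map_updated",
  stream_id: "queen",
  data: {
    map: { start: ["fetch", "plan"], fetch: ["summarize"] }, // node id -> child node ids
    original_draft: null, // DraftGraphData once the queen has staged a draft
  },
};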
// --- Per-agent backend state (consolidated) ---
@@ -339,6 +352,10 @@ interface AgentBackendState {
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
/** Whether the pending question came from queen or worker */
pendingQuestionSource: "queen" | "worker" | null;
/** Per-node context window usage (from context_usage_updated events) */
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
/** Whether the queen's LLM supports image content — false disables the attach button */
queenSupportsImages: boolean;
}
function defaultAgentState(): AgentBackendState {
@@ -376,6 +393,8 @@ function defaultAgentState(): AgentBackendState {
pendingOptions: null,
pendingQuestions: null,
pendingQuestionSource: null,
contextUsage: {},
queenSupportsImages: true,
};
}
@@ -557,7 +576,11 @@ export default function Workspace() {
const [dismissedBanner, setDismissedBanner] = useState<string | null>(null);
const [selectedNode, setSelectedNode] = useState<GraphNode | null>(null);
const [triggerTaskDraft, setTriggerTaskDraft] = useState("");
const [triggerCronDraft, setTriggerCronDraft] = useState("");
const [triggerTaskSaving, setTriggerTaskSaving] = useState(false);
const [triggerScheduleSaving, setTriggerScheduleSaving] = useState(false);
const [triggerCronSaved, setTriggerCronSaved] = useState(false);
const [triggerTaskSaved, setTriggerTaskSaved] = useState(false);
const [newTabOpen, setNewTabOpen] = useState(false);
const newTabBtnRef = useRef<HTMLButtonElement>(null);
const [graphPanelPct, setGraphPanelPct] = useState(30);
@@ -613,6 +636,10 @@ export default function Workspace() {
// it was created in (avoids a stale closure when phase-change and message
// events arrive in the same React batch).
const queenPhaseRef = useRef<Record<string, string>>({});
// Accumulated queen text across inner_turns within the same iteration.
// Key: `${agentType}:${execution_id}:${iteration}`, value: { [inner_turn]: snapshot }.
// This lets us merge all inner_turn text into one chat bubble per iteration.
const queenIterTextRef = useRef<Record<string, Record<number, string>>>({});
// Timestamp when designingDraft was set — used to enforce minimum spinner duration.
const designingDraftSinceRef = useRef<Record<string, number>>({});
const designingDraftTimerRef = useRef<Record<string, ReturnType<typeof setTimeout>>>({});
@@ -794,6 +821,8 @@ export default function Workspace() {
}
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
let restoredFlowchartMap: Record<string, string[]> | null = null;
let restoredOriginalDraft: DraftGraphData | null = null;
if (!liveSession) {
// Fetch conversation history from disk BEFORE creating the new session.
// SKIP if messages were already pre-populated by handleHistoryOpen.
@@ -805,9 +834,22 @@ export default function Workspace() {
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
preRestoredMsgs.push(...restored.messages);
restoredPhase = restored.restoredPhase;
restoredFlowchartMap = restored.flowchartMap;
restoredOriginalDraft = restored.originalDraft;
} catch {
// Not available — will start fresh
}
} else if (restoreFrom && alreadyHasMessages) {
// Messages already cached in localStorage — still fetch events for
// non-message state (phase, flowchart) that isn't cached.
try {
const restored = await restoreSessionMessages(restoreFrom, agentType, "Queen Bee");
restoredPhase = restored.restoredPhase;
restoredFlowchartMap = restored.flowchartMap;
restoredOriginalDraft = restored.originalDraft;
} catch {
// Not critical — UI will still show cached messages
}
}
// Suppress the queen's intro cycle whenever we are about to restore a
@@ -830,7 +872,7 @@ export default function Workspace() {
}));
}
restoredMessageCount = preRestoredMsgs.length;
} else if (restoreFrom && activeId) {
} else if (restoreFrom && activeId && !alreadyHasMessages) {
// We had a stored session but no messages on disk — wipe stale localStorage cache
setSessionsByAgent(prev => ({
...prev,
@@ -884,6 +926,10 @@ export default function Workspace() {
queenReady: true,
queenPhase: qPhase,
queenBuilding: qPhase === "building",
queenSupportsImages: liveSession.queen_supports_images !== false,
// Restore flowchart overlay from persisted events
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
});
} catch (err: unknown) {
const msg = err instanceof Error ? err.message : String(err);
@@ -958,6 +1004,8 @@ export default function Workspace() {
// Track the last queen phase seen in the event log for cold restore
let restoredPhase: "planning" | "building" | "staging" | "running" | null = null;
let restoredFlowchartMap: Record<string, string[]> | null = null;
let restoredOriginalDraft: DraftGraphData | null = null;
if (!liveSession) {
// Reconnect failed — clear stale cached messages from localStorage restore.
@@ -985,6 +1033,19 @@ export default function Workspace() {
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
preQueenMsgs = restored.messages;
restoredPhase = restored.restoredPhase;
restoredFlowchartMap = restored.flowchartMap;
restoredOriginalDraft = restored.originalDraft;
} else if (coldRestoreId && alreadyHasMessages) {
// Messages already cached — still fetch events for non-message state (phase, flowchart)
try {
const displayNameTemp = formatAgentDisplayName(agentPath);
const restored = await restoreSessionMessages(coldRestoreId, agentType, displayNameTemp);
restoredPhase = restored.restoredPhase;
restoredFlowchartMap = restored.flowchartMap;
restoredOriginalDraft = restored.originalDraft;
} catch {
// Not critical — UI will still show cached messages
}
}
// Suppress intro whenever we are about to restore a previous conversation.
@@ -1065,6 +1126,10 @@ export default function Workspace() {
displayName,
queenPhase: initialPhase,
queenBuilding: initialPhase === "building",
queenSupportsImages: session.queen_supports_images !== false,
// Restore flowchart overlay from persisted events
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
});
// Update the session label + backendSessionId. Also set historySourceId
@@ -1102,6 +1167,11 @@ export default function Workspace() {
if (historyId && !coldRestoreId) {
const restored = await restoreSessionMessages(historyId, agentType, displayName);
restoredMsgs.push(...restored.messages);
// Use flowchart from event log if not already set
if (restored.flowchartMap && !restoredFlowchartMap) {
restoredFlowchartMap = restored.flowchartMap;
restoredOriginalDraft = restored.originalDraft;
}
// Check worker status (needed for isWorkerRunning flag)
try {
@@ -1144,6 +1214,9 @@ export default function Workspace() {
loading: false,
queenReady: !!(isResumedSession || hasRestoredContent),
...(isWorkerRunning ? { workerRunState: "running" } : {}),
// Restore flowchart overlay from persisted events
...(restoredFlowchartMap ? { flowchartMap: restoredFlowchartMap } : {}),
...(restoredOriginalDraft ? { originalDraft: restoredOriginalDraft, draftGraph: null } : {}),
});
} catch (err: unknown) {
const msg = err instanceof Error ? err.message : String(err);
@@ -1260,12 +1333,28 @@ export default function Workspace() {
const fireMap = new Map<string, number>();
const taskMap = new Map<string, string>();
const labelMap = new Map<string, string>();
const targetMap = new Map<string, string>();
for (const ep of triggerEps) {
const nodeId = `__trigger_${ep.id}`;
if (ep.next_fire_in != null) {
fireMap.set(`__trigger_${ep.id}`, ep.next_fire_in);
fireMap.set(nodeId, ep.next_fire_in);
}
if (ep.task != null) {
taskMap.set(`__trigger_${ep.id}`, ep.task);
taskMap.set(nodeId, ep.task);
}
const cron = ep.trigger_config?.cron as string | undefined;
const interval = ep.trigger_config?.interval_minutes as number | undefined;
const epLabel = cron
? cronToLabel(cron)
: interval
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
: ep.name || undefined;
if (epLabel) {
labelMap.set(nodeId, epLabel);
}
if (ep.entry_node) {
targetMap.set(nodeId, ep.entry_node);
}
}
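Concretely, the label rules above map endpoints to display strings like this (illustrative inputs):
// { id: "t1", trigger_config: { cron: "*/10 * * * *" } }   -> "Every 10m"
// { id: "t2", trigger_config: { interval_minutes: 90 } }   -> "Every 1.5h"
// { id: "t3", name: "On deploy", trigger_config: {} }      -> "On deploy"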
@@ -1274,14 +1363,18 @@ export default function Workspace() {
if (!ss?.length) return prev;
const existingIds = new Set(ss[0].graphNodes.map(n => n.id));
// Update existing trigger nodes
// Update existing trigger nodes (countdown, task, label, target)
let updated = ss[0].graphNodes.map((n) => {
if (n.nodeType !== "trigger") return n;
const nfi = fireMap.get(n.id);
const task = taskMap.get(n.id);
if (nfi == null && task == null) return n;
const label = labelMap.get(n.id);
const target = targetMap.get(n.id);
if (nfi == null && task == null && !label && !target) return n;
return {
...n,
...(label && label !== n.label ? { label } : {}),
...(target ? { next: [target] } : {}),
triggerConfig: {
...n.triggerConfig,
...(nfi != null ? { next_fire_in: nfi } : {}),
@@ -1291,14 +1384,15 @@ export default function Workspace() {
});
// Discover new triggers not yet in the graph
const entryNode = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
const fallbackEntry = ss[0].graphNodes.find(n => n.nodeType !== "trigger")?.id;
const newNodes: GraphNode[] = [];
for (const ep of triggerEps) {
const nodeId = `__trigger_${ep.id}`;
if (existingIds.has(nodeId)) continue;
const target = ep.entry_node || fallbackEntry;
newNodes.push({
id: nodeId,
label: ep.name || ep.id,
label: labelMap.get(nodeId) || ep.name || ep.id,
status: "pending",
nodeType: "trigger",
triggerType: ep.trigger_type,
@@ -1307,7 +1401,7 @@ export default function Workspace() {
...(ep.next_fire_in != null ? { next_fire_in: ep.next_fire_in } : {}),
...(ep.task ? { task: ep.task } : {}),
},
...(entryNode ? { next: [entryNode] } : {}),
...(target ? { next: [target] } : {}),
});
}
if (newNodes.length > 0) {
@@ -1625,14 +1719,29 @@ export default function Workspace() {
if (isQueen) console.log('[QUEEN] chatMsg:', chatMsg?.id, chatMsg?.content?.slice(0, 50), 'turn:', currentTurn);
if (chatMsg && !suppressQueenMessages) {
// Queen emits multiple client_output_delta / llm_text_delta snapshots
// across iterations and inner tool-loop turns. Build a stable ID that
// groups streaming deltas for the *same* output (same execution +
// iteration + inner_turn) into one bubble, while keeping distinct
// outputs as separate bubbles so earlier text isn't overwritten.
// across iterations and inner tool-loop turns. Merge all inner_turns
// within the same iteration into ONE bubble so the queen's multi-step
// tool loop (text → tool → text → tool → text) appears as one cohesive
// message rather than many small fragments.
if (isQueen && (event.type === "client_output_delta" || event.type === "llm_text_delta") && event.execution_id) {
const iter = event.data?.iteration ?? 0;
const inner = event.data?.inner_turn ?? 0;
chatMsg.id = `queen-stream-${event.execution_id}-${iter}-${inner}`;
const inner = (event.data?.inner_turn as number) ?? 0;
const iterKey = `${agentType}:${event.execution_id}:${iter}`;
// Store the latest snapshot for this inner_turn
if (!queenIterTextRef.current[iterKey]) {
queenIterTextRef.current[iterKey] = {};
}
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
queenIterTextRef.current[iterKey][inner] = snapshot;
// Concatenate all inner_turn snapshots in order
const parts = queenIterTextRef.current[iterKey];
const sortedInners = Object.keys(parts).map(Number).sort((a, b) => a - b);
chatMsg.content = sortedInners.map(k => parts[k]).join("\n");
// Single ID per iteration — no inner_turn in the ID
chatMsg.id = `queen-stream-${event.execution_id}-${iter}`;
}
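Concretely, the merge behaves like this (illustrative snapshots):
// Three inner_turns within iteration 2 of execution "ex1" accumulate under
// queenIterTextRef.current["worker:ex1:2"] as
//   { 0: "Looking at the spec...", 1: "Ran the search tool.", 2: "Here is the plan." }
// and the single bubble with id "queen-stream-ex1-2" renders
//   "Looking at the spec...\nRan the search tool.\nHere is the plan."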
if (isQueen) {
chatMsg.role = role;
@@ -1907,6 +2016,8 @@ export default function Workspace() {
role,
thread: agentType,
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
});
return {
...prev,
@@ -1978,6 +2089,8 @@ export default function Workspace() {
role,
thread: agentType,
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
});
return {
...prev,
@@ -2054,6 +2167,29 @@ export default function Workspace() {
}
break;
case "context_usage_updated": {
const streamKey = isQueen ? "__queen__" : (event.node_id || streamId);
const usagePct = (event.data?.usage_pct as number) ?? 0;
const messageCount = (event.data?.message_count as number) ?? 0;
const estimatedTokens = (event.data?.estimated_tokens as number) ?? 0;
const maxTokens = (event.data?.max_context_tokens as number) ?? 0;
setAgentStates(prev => {
const state = prev[agentType];
if (!state) return prev;
return {
...prev,
[agentType]: {
...state,
contextUsage: {
...state.contextUsage,
[streamKey]: { usagePct, messageCount, estimatedTokens, maxTokens },
},
},
};
});
}
break;
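For reference, an assumed context_usage_updated payload, with field names inferred from the reads above (illustrative, not from the source):
// Illustrative event:
//   { type: "context_usage_updated", node_id: "summarize",
//     data: { usage_pct: 42.5, message_count: 31,
//             estimated_tokens: 54000, max_context_tokens: 128000 } }
// The handler stores it under contextUsage["summarize"] as
//   { usagePct: 42.5, messageCount: 31, estimatedTokens: 54000, maxTokens: 128000 }.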
case "node_action_plan":
if (!isQueen && event.node_id) {
const plan = (event.data?.plan as string) || "";
@@ -2237,10 +2373,18 @@ export default function Workspace() {
// Synthesize new trigger node at the front of the graph
const triggerType = (event.data?.trigger_type as string) || "timer";
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
const triggerName = (event.data?.name as string) || triggerId;
const _cron = triggerConfig.cron as string | undefined;
const _interval = triggerConfig.interval_minutes as number | undefined;
const computedLabel = _cron
? cronToLabel(_cron)
: _interval
? `Every ${_interval >= 60 ? `${_interval / 60}h` : `${_interval}m`}`
: triggerName;
const newNode: GraphNode = {
id: nodeId,
label: triggerId,
label: computedLabel,
status: "running",
nodeType: "trigger",
triggerType,
@@ -2305,10 +2449,18 @@ export default function Workspace() {
if (s.graphNodes.some(n => n.id === nodeId)) return s;
const triggerType = (event.data?.trigger_type as string) || "timer";
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
const entryNode = s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
const entryNode = (event.data?.entry_node as string) || s.graphNodes.find(n => n.nodeType !== "trigger")?.id;
const triggerName = (event.data?.name as string) || triggerId;
const _cron2 = triggerConfig.cron as string | undefined;
const _interval2 = triggerConfig.interval_minutes as number | undefined;
const computedLabel2 = _cron2
? cronToLabel(_cron2)
: _interval2
? `Every ${_interval2 >= 60 ? `${_interval2 / 60}h` : `${_interval2}m`}`
: triggerName;
const newNode: GraphNode = {
id: nodeId,
label: triggerId,
label: computedLabel2,
status: "pending",
nodeType: "trigger",
triggerType,
@@ -2323,6 +2475,43 @@ export default function Workspace() {
break;
}
case "trigger_updated": {
const triggerId = event.data?.trigger_id as string;
if (triggerId) {
const nodeId = `__trigger_${triggerId}`;
const triggerConfig = (event.data?.trigger_config as Record<string, unknown>) || {};
const cron = triggerConfig.cron as string | undefined;
const interval = triggerConfig.interval_minutes as number | undefined;
const newLabel = cron
? cronToLabel(cron)
: interval
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
: undefined;
setSessionsByAgent(prev => {
const sessions = prev[agentType] || [];
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
return {
...prev,
[agentType]: sessions.map(s => {
if (s.id !== activeId) return s;
return {
...s,
graphNodes: s.graphNodes.map(n => {
if (n.id !== nodeId) return n;
return {
...n,
...(newLabel ? { label: newLabel } : {}),
triggerConfig: { ...n.triggerConfig, ...triggerConfig },
};
}),
};
}),
};
});
}
break;
}
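A sketch of the event this case consumes, with the shape inferred from the reads above (illustrative only):
// Illustrative trigger_updated event:
//   { type: "trigger_updated",
//     data: { trigger_id: "daily-report", trigger_config: { cron: "0 9 * * *" } } }
// Node "__trigger_daily-report" would be relabeled "Daily at 9AM" and its
// triggerConfig shallow-merged with the new config.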
case "trigger_removed": {
const triggerId = event.data?.trigger_id as string;
if (triggerId) {
@@ -2376,14 +2565,43 @@ export default function Workspace() {
const liveSelectedNode = selectedNode && currentGraph.nodes.find(n => n.id === selectedNode.id);
const resolvedSelectedNode = liveSelectedNode || selectedNode;
// Sync trigger task draft when selected trigger node changes
// Sync trigger drafts when selected trigger node changes
useEffect(() => {
if (resolvedSelectedNode?.nodeType === "trigger") {
const tc = resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined;
setTriggerTaskDraft((tc?.task as string) || "");
setTriggerCronDraft((tc?.cron as string) || "");
}
}, [resolvedSelectedNode?.id]);
const patchTriggerNode = useCallback((agentType: string, triggerNodeId: string, patch: { task?: string; trigger_config?: Record<string, unknown>; label?: string }) => {
setSessionsByAgent(prev => {
const sessions = prev[agentType] || [];
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
return {
...prev,
[agentType]: sessions.map(s => {
if (s.id !== activeId) return s;
return {
...s,
graphNodes: s.graphNodes.map(n => {
if (n.id !== triggerNodeId) return n;
return {
...n,
...(patch.label !== undefined ? { label: patch.label } : {}),
triggerConfig: {
...n.triggerConfig,
...(patch.trigger_config || {}),
...(patch.task !== undefined ? { task: patch.task } : {}),
},
};
}),
};
}),
};
});
}, []);
// Build a flat list of all agent-type tabs for the tab bar
const agentTabs = Object.entries(sessionsByAgent)
.filter(([, sessions]) => sessions.length > 0)
@@ -2400,7 +2618,7 @@ export default function Workspace() {
});
// --- handleSend ---
const handleSend = useCallback((text: string, thread: string) => {
const handleSend = useCallback((text: string, thread: string, images?: import("@/components/ChatPanel").ImageContent[]) => {
if (!activeSession) return;
const state = agentStates[activeWorker];
@@ -2466,6 +2684,7 @@ export default function Workspace() {
const userMsg: ChatMessage = {
id: makeId(), agent: "You", agentColor: "",
content: text, timestamp: "", type: "user", thread, createdAt: Date.now(),
images,
};
setSessionsByAgent(prev => ({
...prev,
@@ -2477,7 +2696,7 @@ export default function Workspace() {
updateAgentState(activeWorker, { isTyping: true, queenIsTyping: true });
if (state?.sessionId && state?.ready) {
executionApi.chat(state.sessionId, text).catch((err: unknown) => {
executionApi.chat(state.sessionId, text, images).catch((err: unknown) => {
const errMsg = err instanceof Error ? err.message : String(err);
const errorChatMsg: ChatMessage = {
id: makeId(), agent: "System", agentColor: "",
@@ -2893,6 +3112,16 @@ export default function Workspace() {
<KeyRound className="w-3.5 h-3.5" />
Credentials
</button>
{activeAgentState?.sessionId && (
<button
onClick={() => sessionsApi.revealFolder(activeAgentState.sessionId!).catch(() => {})}
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/50 transition-colors flex-shrink-0"
title="Open session data folder"
>
<FolderOpen className="w-3.5 h-3.5" />
Data
</button>
)}
</TopBar>
{/* Main content area */}
@@ -3010,6 +3239,8 @@ export default function Workspace() {
}
onMultiQuestionSubmit={handleMultiQuestionAnswer}
onQuestionDismiss={handleQuestionDismiss}
contextUsage={activeAgentState?.contextUsage}
supportsImages={activeAgentState?.queenSupportsImages ?? true}
/>
)}
</div>
@@ -3052,18 +3283,64 @@ export default function Workspace() {
const interval = tc?.interval_minutes as number | undefined;
const eventTypes = tc?.event_types as string[] | undefined;
const scheduleLabel = cron
? `cron: ${cron}`
? cronToLabel(cron)
: interval
? `Every ${interval >= 60 ? `${interval / 60}h` : `${interval}m`}`
: eventTypes?.length
? eventTypes.join(", ")
: null;
return scheduleLabel ? (
const canEditCron = resolvedSelectedNode.triggerType === "timer";
const cronChanged = canEditCron && triggerCronDraft.trim() !== (cron || "");
return scheduleLabel || canEditCron ? (
<div>
<p className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider mb-1.5">Schedule</p>
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
{scheduleLabel}
</p>
{scheduleLabel && (
<p className="text-xs text-foreground/80 font-mono bg-muted/30 rounded-lg px-3 py-2 border border-border/20">
{scheduleLabel}
</p>
)}
{canEditCron && (
<>
<input
value={triggerCronDraft}
onChange={(e) => setTriggerCronDraft(e.target.value)}
placeholder="0 5 * * *"
className="mt-1.5 w-full text-xs text-foreground/80 bg-muted/30 rounded-lg px-3 py-2 border border-border/20 font-mono focus:outline-none focus:border-primary/40"
/>
<p className="text-[10px] text-muted-foreground/60 mt-1">
Edit the cron expression for this timer trigger.
</p>
{(cronChanged || triggerCronSaved) && (
<button
disabled={triggerScheduleSaving || !cronChanged}
onClick={async () => {
const sessionId = activeAgentState?.sessionId;
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
const nextCron = triggerCronDraft.trim();
if (!sessionId || !nextCron) return;
const nextTriggerConfig: Record<string, unknown> = { cron: nextCron };
setTriggerScheduleSaving(true);
try {
await sessionsApi.updateTrigger(sessionId, triggerId, {
trigger_config: nextTriggerConfig,
});
patchTriggerNode(activeWorker, resolvedSelectedNode.id, {
trigger_config: nextTriggerConfig,
label: cronToLabel(nextCron),
});
setTriggerCronSaved(true);
setTimeout(() => setTriggerCronSaved(false), 2000);
} finally {
setTriggerScheduleSaving(false);
}
}}
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
>
{triggerScheduleSaving ? "Saving..." : triggerCronSaved ? "Saved" : "Save Cron"}
</button>
)}
</>
)}
</div>
) : null;
})()}
@@ -3090,24 +3367,27 @@ export default function Workspace() {
{(() => {
const currentTask = (resolvedSelectedNode.triggerConfig as Record<string, unknown> | undefined)?.task as string || "";
const hasChanged = triggerTaskDraft !== currentTask;
if (!hasChanged) return null;
if (!hasChanged && !triggerTaskSaved) return null;
return (
<button
disabled={triggerTaskSaving}
disabled={triggerTaskSaving || !hasChanged}
onClick={async () => {
const sessionId = activeAgentState?.sessionId;
const triggerId = resolvedSelectedNode.id.replace("__trigger_", "");
if (!sessionId) return;
setTriggerTaskSaving(true);
try {
await sessionsApi.updateTriggerTask(sessionId, triggerId, triggerTaskDraft);
await sessionsApi.updateTrigger(sessionId, triggerId, { task: triggerTaskDraft });
patchTriggerNode(activeWorker, resolvedSelectedNode.id, { task: triggerTaskDraft });
setTriggerTaskSaved(true);
setTimeout(() => setTriggerTaskSaved(false), 2000);
} finally {
setTriggerTaskSaving(false);
}
}}
className="mt-1.5 w-full text-[11px] px-3 py-1.5 rounded-lg border border-primary/30 text-primary hover:bg-primary/10 transition-colors disabled:opacity-50"
>
{triggerTaskSaving ? "Saving..." : "Save Task"}
{triggerTaskSaving ? "Saving..." : triggerTaskSaved ? "Saved" : "Save Task"}
</button>
);
})()}
@@ -3164,6 +3444,7 @@ export default function Workspace() {
workerSessionId={null}
nodeLogs={activeAgentState?.nodeLogs[resolvedSelectedNode.id] || []}
actionPlan={activeAgentState?.nodeActionPlans[resolvedSelectedNode.id]}
contextUsage={activeAgentState?.contextUsage[resolvedSelectedNode.id]}
onClose={() => setSelectedNode(null)}
/>
)}
+4
@@ -62,6 +62,10 @@ lint.isort.section-order = [
"first-party",
"local-folder",
]
[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning:litellm.*"
]
[dependency-groups]
dev = [
+172
@@ -0,0 +1,172 @@
"""Integration test: Run a real EventLoopNode against the Antigravity backend.
Run: .venv/bin/python core/tests/test_antigravity_eventloop.py
Requires:
- ~/.hive/antigravity-accounts.json with valid credentials
(run 'uv run python core/antigravity_auth.py auth account add' to authenticate)
"""
import asyncio
import logging
import sys
from unittest.mock import MagicMock
sys.path.insert(0, "core")
logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)
from framework.config import RuntimeConfig # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig # noqa: E402
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory # noqa: E402
from framework.llm.litellm import LiteLLMProvider # noqa: E402
def make_provider() -> LiteLLMProvider:
cfg = RuntimeConfig()
if not cfg.api_key:
print("ERROR: No Antigravity token found.")
print(" 1. Run 'antigravity-auth accounts add' to authenticate.")
print(" 2. Run 'antigravity-auth serve' to start the local proxy.")
print(" 3. Configure Hive: run quickstart.sh and select option 7 (Antigravity).")
sys.exit(1)
print(f"Model : {cfg.model}")
print(f"Base : {cfg.api_base}")
print(f"Antigravity : {'localhost:8069' in (cfg.api_base or '')}")
return LiteLLMProvider(
model=cfg.model,
api_key=cfg.api_key,
api_base=cfg.api_base,
**cfg.extra_kwargs,
)
def make_context(
llm: LiteLLMProvider,
*,
node_id: str = "test",
system_prompt: str = "You are a helpful assistant.",
output_keys: list[str] | None = None,
) -> NodeContext:
if output_keys is None:
output_keys = ["answer"]
spec = NodeSpec(
id=node_id,
name="Test Node",
description="Integration test node",
node_type="event_loop",
output_keys=output_keys,
system_prompt=system_prompt,
)
runtime = MagicMock()
runtime.start_run = MagicMock(return_value="run-1")
runtime.decide = MagicMock(return_value="dec-1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
memory = SharedMemory()
return NodeContext(
runtime=runtime,
node_id=node_id,
node_spec=spec,
memory=memory,
input_data={},
llm=llm,
available_tools=[],
max_tokens=4096,
)
async def run_test(
name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]
) -> NodeResult:
print(f"\n{'=' * 60}")
print(f"TEST: {name}")
print(f"{'=' * 60}")
ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
node = EventLoopNode(config=LoopConfig(max_iterations=3))
try:
result = await node.execute(ctx)
print(f" Success : {result.success}")
print(f" Output : {result.output}")
if result.error:
print(f" Error : {result.error}")
return result
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return NodeResult(success=False, error=str(e))
async def main():
llm = make_provider()
print()
# Test 1: Simple text output — the node should call set_output to fill "answer"
r1 = await run_test(
name="Simple text generation",
llm=llm,
system=(
"You are a helpful assistant. When asked a question, use the "
"set_output tool to store your answer in the 'answer' key. "
"Keep answers short (1-2 sentences)."
),
output_keys=["answer"],
)
# Test 2: If test 1 failed, try bare stream() to isolate the issue
if not r1.success:
print(f"\n{'=' * 60}")
print("FALLBACK: Testing bare provider.stream() directly")
print(f"{'=' * 60}")
try:
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
ToolCallEvent,
)
text = ""
events = []
async for event in llm.stream(
messages=[{"role": "user", "content": "Say hello in 3 words."}],
):
events.append(type(event).__name__)
if isinstance(event, TextDeltaEvent):
text = event.snapshot
elif isinstance(event, FinishEvent):
print(
f" Finish: stop={event.stop_reason}"
f" in={event.input_tokens}"
f" out={event.output_tokens}"
)
elif isinstance(event, StreamErrorEvent):
print(f" StreamError: {event.error} (recoverable={event.recoverable})")
elif isinstance(event, ToolCallEvent):
print(f" ToolCall: {event.tool_name}")
print(f" Text : {text!r}")
print(f" Events : {events}")
print(f" RESULT : {'OK' if text else 'EMPTY'}")
except Exception as e:
print(f" EXCEPTION: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{'=' * 60}")
print("DONE")
print(f"{'=' * 60}")
if __name__ == "__main__":
asyncio.run(main())
+58
@@ -0,0 +1,58 @@
"""Tests for LLM model capability checks."""
from __future__ import annotations
import pytest
from framework.llm.capabilities import supports_image_tool_results
class TestSupportsImageToolResults:
"""Verify the deny-list correctly identifies models that can't handle images."""
@pytest.mark.parametrize(
"model",
[
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"openai/gpt-4o",
"anthropic/claude-sonnet-4-20250514",
"claude-haiku-4-5-20251001",
"gemini/gemini-1.5-pro",
"google/gemini-1.5-flash",
"mistral/mistral-large",
"groq/llama3-70b",
"together/meta-llama/Llama-3-70b",
"fireworks_ai/llama-v3-70b",
"azure/gpt-4o",
"kimi/claude-sonnet-4-20250514",
"hive/claude-sonnet-4-20250514",
],
)
def test_supported_models(self, model: str):
assert supports_image_tool_results(model) is True
@pytest.mark.parametrize(
"model",
[
"deepseek/deepseek-chat",
"deepseek/deepseek-coder",
"deepseek-chat",
"deepseek-reasoner",
"ollama/llama3",
"ollama/mistral",
"ollama_chat/llama3",
"lm_studio/my-model",
"vllm/meta-llama/Llama-3-70b",
"llamacpp/model",
"cerebras/llama3-70b",
],
)
def test_unsupported_models(self, model: str):
assert supports_image_tool_results(model) is False
def test_case_insensitive(self):
assert supports_image_tool_results("DeepSeek/deepseek-chat") is False
assert supports_image_tool_results("OLLAMA/llama3") is False
assert supports_image_tool_results("GPT-4o") is True
+209
@@ -0,0 +1,209 @@
import importlib.util
from pathlib import Path
def _load_check_llm_key_module():
module_path = Path(__file__).resolve().parents[2] / "scripts" / "check_llm_key.py"
spec = importlib.util.spec_from_file_location("check_llm_key_script", module_path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def _run_openrouter_check(monkeypatch, status_code: int):
module = _load_check_llm_key_module()
calls = {}
class FakeResponse:
def __init__(self, code):
self.status_code = code
class FakeClient:
def __init__(self, timeout):
calls["timeout"] = timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def get(self, endpoint, headers):
calls["endpoint"] = endpoint
calls["headers"] = headers
return FakeResponse(status_code)
monkeypatch.setattr(module.httpx, "Client", FakeClient)
result = module.check_openrouter("test-key")
return result, calls
def _run_openrouter_model_check(
monkeypatch,
status_code: int,
payload: dict | None = None,
model: str = "openai/gpt-4o-mini",
):
module = _load_check_llm_key_module()
calls = {}
class FakeResponse:
def __init__(self, code):
self.status_code = code
self._payload = payload
self.text = ""
def json(self):
if self._payload is None:
raise ValueError("no json")
return self._payload
class FakeClient:
def __init__(self, timeout):
calls["timeout"] = timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def get(self, endpoint, headers):
calls["endpoint"] = endpoint
calls["headers"] = headers
return FakeResponse(status_code)
monkeypatch.setattr(module.httpx, "Client", FakeClient)
result = module.check_openrouter_model("test-key", model)
return result, calls
def test_check_openrouter_200(monkeypatch):
result, calls = _run_openrouter_check(monkeypatch, 200)
assert result == {"valid": True, "message": "OpenRouter API key valid"}
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models"
assert calls["headers"] == {"Authorization": "Bearer test-key"}
def test_check_openrouter_401(monkeypatch):
result, _ = _run_openrouter_check(monkeypatch, 401)
assert result == {"valid": False, "message": "Invalid OpenRouter API key"}
def test_check_openrouter_403(monkeypatch):
result, _ = _run_openrouter_check(monkeypatch, 403)
assert result == {"valid": False, "message": "OpenRouter API key lacks permissions"}
def test_check_openrouter_429(monkeypatch):
result, _ = _run_openrouter_check(monkeypatch, 429)
assert result == {"valid": True, "message": "OpenRouter API key valid"}
def test_check_openrouter_model_200(monkeypatch):
result, calls = _run_openrouter_model_check(
monkeypatch,
200,
{
"data": [
{
"id": "openai/gpt-4o-mini",
"canonical_slug": "openai/gpt-4o-mini",
}
]
},
)
assert result == {
"valid": True,
"message": "OpenRouter model is available: openai/gpt-4o-mini",
"model": "openai/gpt-4o-mini",
}
assert calls["endpoint"] == "https://openrouter.ai/api/v1/models/user"
assert calls["headers"] == {"Authorization": "Bearer test-key"}
def test_check_openrouter_model_200_matches_canonical_slug(monkeypatch):
result, _ = _run_openrouter_model_check(
monkeypatch,
200,
{
"data": [
{
"id": "mistralai/mistral-small-4",
"canonical_slug": "mistralai/mistral-small-2603",
}
]
},
model="mistralai/mistral-small-2603",
)
assert result == {
"valid": True,
"message": "OpenRouter model is available: mistralai/mistral-small-2603",
"model": "mistralai/mistral-small-2603",
}
def test_check_openrouter_model_200_sanitizes_pasted_unicode(monkeypatch):
result, _ = _run_openrouter_model_check(
monkeypatch,
200,
{
"data": [
{
"id": "z-ai/glm-5-turbo",
"canonical_slug": "z-ai/glm-5-turbo",
}
]
},
model="openrouter/z-ai\u200b/glm\u20115\u2011turbo",
)
assert result == {
"valid": True,
"message": "OpenRouter model is available: z-ai/glm-5-turbo",
"model": "z-ai/glm-5-turbo",
}
def test_check_openrouter_model_200_not_found_with_suggestions(monkeypatch):
result, _ = _run_openrouter_model_check(
monkeypatch,
200,
{
"data": [
{"id": "z-ai/glm-5-turbo"},
{"id": "z-ai/glm-4.6v"},
]
},
model="z-ai/glm-5-turb",
)
assert result == {
"valid": False,
"message": (
"OpenRouter model is not available for this key/settings: z-ai/glm-5-turb. "
"Closest matches: z-ai/glm-5-turbo"
),
}
def test_check_openrouter_model_404_with_error_message(monkeypatch):
result, _ = _run_openrouter_model_check(
monkeypatch,
404,
{"error": {"message": "No endpoints available for this model"}},
)
assert result == {
"valid": False,
"message": (
"OpenRouter model is not available for this key/settings: openai/gpt-4o-mini. "
"No endpoints available for this model"
),
}
def test_check_openrouter_model_429(monkeypatch):
result, _ = _run_openrouter_model_check(monkeypatch, 429)
assert result == {
"valid": True,
"message": "OpenRouter model check rate-limited; assuming model is reachable",
}
+45 -1
@@ -2,7 +2,7 @@
import logging
from framework.config import get_hive_config
from framework.config import get_api_base, get_hive_config, get_preferred_model
class TestGetHiveConfig:
@@ -21,3 +21,47 @@ class TestGetHiveConfig:
assert result == {}
assert "Failed to load Hive config" in caplog.text
assert str(config_file) in caplog.text
class TestOpenRouterConfig:
"""OpenRouter config composition and fallback behavior."""
def test_get_preferred_model_for_openrouter(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
def test_get_preferred_model_normalizes_openrouter_prefixed_model(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openrouter","model":"openrouter/x-ai/grok-4.20-beta"}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_preferred_model() == "openrouter/x-ai/grok-4.20-beta"
def test_get_api_base_falls_back_to_openrouter_default(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta"}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_api_base() == "https://openrouter.ai/api/v1"
def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_api_base() == "https://proxy.example/v1"
+70
@@ -0,0 +1,70 @@
import os
import sys
from types import ModuleType, SimpleNamespace
from framework.credentials import key_storage
from framework.credentials.validation import ensure_credential_key_env
def _install_fake_aden_modules(monkeypatch, check_fn, credential_specs):
shell_config_module = ModuleType("aden_tools.credentials.shell_config")
shell_config_module.check_env_var_in_shell_config = check_fn
credentials_module = ModuleType("aden_tools.credentials")
credentials_module.CREDENTIAL_SPECS = credential_specs
monkeypatch.setitem(sys.modules, "aden_tools.credentials.shell_config", shell_config_module)
monkeypatch.setitem(sys.modules, "aden_tools.credentials", credentials_module)
def test_bootstrap_loads_configured_llm_env_var_from_shell_config(monkeypatch):
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
monkeypatch.setattr(
"framework.config.get_hive_config",
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
calls = []
def check_env(var_name):
calls.append(var_name)
if var_name == "OPENROUTER_API_KEY":
return True, "or-key-123"
return False, None
_install_fake_aden_modules(
monkeypatch,
check_env,
{"anthropic": SimpleNamespace(env_var="ANTHROPIC_API_KEY")},
)
ensure_credential_key_env()
assert os.environ.get("OPENROUTER_API_KEY") == "or-key-123"
assert "OPENROUTER_API_KEY" in calls
def test_bootstrap_does_not_override_existing_configured_llm_env_var(monkeypatch):
monkeypatch.setattr(key_storage, "load_credential_key", lambda: None)
monkeypatch.setattr(key_storage, "load_aden_api_key", lambda: None)
monkeypatch.setattr(
"framework.config.get_hive_config",
lambda: {"llm": {"api_key_env_var": "OPENROUTER_API_KEY"}},
)
monkeypatch.setenv("OPENROUTER_API_KEY", "already-set")
calls = []
def check_env(var_name):
calls.append(var_name)
return True, "new-value-should-not-apply"
_install_fake_aden_modules(monkeypatch, check_env, {})
ensure_credential_key_env()
assert os.environ.get("OPENROUTER_API_KEY") == "already-set"
assert "OPENROUTER_API_KEY" not in calls
+28
@@ -1530,6 +1530,34 @@ class TestTransientErrorRetry:
await node.execute(ctx)
assert llm._call_index == 1 # only tried once
@pytest.mark.asyncio
async def test_client_facing_non_transient_error_does_not_crash(
self, runtime, node_spec, memory
):
"""Client-facing non-transient errors should wait for input, not crash on token vars."""
node_spec.output_keys = []
node_spec.client_facing = True
llm = ErrorThenSuccessLLM(
error=ValueError("bad request: blocked by policy"),
fail_count=100, # always fails
success_scenario=text_scenario("unreachable"),
)
ctx = build_ctx(runtime, node_spec, memory, llm)
node = EventLoopNode(
config=LoopConfig(
max_iterations=1,
max_stream_retries=0,
stream_retry_backoff_base=0.01,
),
)
node._await_user_input = AsyncMock(return_value=None)
result = await node.execute(ctx)
assert result.success is False
assert "Max iterations" in (result.error or "")
node._await_user_input.assert_awaited_once()
@pytest.mark.asyncio
async def test_transient_error_exhausts_retries(self, runtime, node_spec, memory):
"""Transient errors that exhaust retries should raise."""
+356 -1
@@ -19,7 +19,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from framework.llm.anthropic import AnthropicProvider
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
from framework.llm.litellm import (
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
LiteLLMProvider,
_compute_retry_delay,
)
from framework.llm.provider import LLMProvider, LLMResponse, Tool
@@ -72,6 +76,20 @@ class TestLiteLLMProviderInit:
)
assert provider.api_base == "https://proxy.example/v1"
def test_init_openrouter_defaults_api_base(self):
"""OpenRouter should default to the official OpenAI-compatible endpoint."""
provider = LiteLLMProvider(model="openrouter/x-ai/grok-4.20-beta", api_key="my-key")
assert provider.api_base == "https://openrouter.ai/api/v1"
def test_init_openrouter_keeps_custom_api_base(self):
"""Explicit api_base should win over OpenRouter defaults."""
provider = LiteLLMProvider(
model="openrouter/x-ai/grok-4.20-beta",
api_key="my-key",
api_base="https://proxy.example/v1",
)
assert provider.api_base == "https://proxy.example/v1"
def test_init_ollama_no_key_needed(self):
"""Test that Ollama models don't require API key."""
with patch.dict(os.environ, {}, clear=True):
@@ -192,6 +210,34 @@ class TestToolConversion:
assert result["function"]["parameters"]["properties"]["query"]["type"] == "string"
assert result["function"]["parameters"]["required"] == ["query"]
def test_parse_tool_call_arguments_repairs_truncated_json(self):
"""Truncated JSON fragments should be repaired into valid tool inputs."""
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
parsed = provider._parse_tool_call_arguments(
(
'{"question":"What story structure should the agent use?",'
'"options":["3-act structure","Beginning-Middle-End","Random paragraph"'
),
"ask_user",
)
assert parsed == {
"question": "What story structure should the agent use?",
"options": [
"3-act structure",
"Beginning-Middle-End",
"Random paragraph",
],
}
def test_parse_tool_call_arguments_raises_when_unrepairable(self):
"""Completely invalid JSON should fail fast instead of producing _raw loops."""
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
class TestAnthropicProviderBackwardCompatibility:
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
@@ -682,6 +728,315 @@ class TestMiniMaxStreamFallback:
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
class TestOpenRouterToolCompatFallback:
"""OpenRouter models should fall back when native tool use is unavailable."""
def teardown_method(self):
OPENROUTER_TOOL_COMPAT_MODEL_CACHE.clear()
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_falls_back_to_json_tool_emulation(self, mock_acompletion):
"""OpenRouter tool-use 404s should emit synthetic ToolCallEvents instead of errors."""
from framework.llm.stream_events import FinishEvent, ToolCallEvent
provider = LiteLLMProvider(
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
api_key="test-key",
)
tools = [
Tool(
name="web_search",
description="Search the web",
parameters={
"properties": {
"query": {"type": "string"},
"num_results": {"type": "integer"},
},
"required": ["query"],
},
)
]
compat_response = MagicMock()
compat_response.choices = [MagicMock()]
compat_response.choices[0].message.content = (
'{"assistant_response":"","tool_calls":['
'{"name":"web_search","arguments":'
'{"query":"Python 3.13 release notes","num_results":3}}'
"]}"
)
compat_response.choices[0].finish_reason = "stop"
compat_response.model = provider.model
compat_response.usage.prompt_tokens = 18
compat_response.usage.completion_tokens = 9
async def side_effect(*args, **kwargs):
if kwargs.get("stream"):
raise RuntimeError(
'OpenrouterException - {"error":{"message":"No endpoints found '
"that support tool use. To learn more about provider routing, "
'visit: https://openrouter.ai/docs/guides/routing/provider-selection",'
'"code":404}}'
)
return compat_response
mock_acompletion.side_effect = side_effect
events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "Search for the Python 3.13 release notes."}],
system="Use tools when needed.",
tools=tools,
max_tokens=256,
):
events.append(event)
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
assert len(tool_calls) == 1
assert tool_calls[0].tool_name == "web_search"
assert tool_calls[0].tool_input == {
"query": "Python 3.13 release notes",
"num_results": 3,
}
assert tool_calls[0].tool_use_id.startswith("openrouter_compat_")
finish_events = [event for event in events if isinstance(event, FinishEvent)]
assert len(finish_events) == 1
assert finish_events[0].stop_reason == "tool_calls"
assert finish_events[0].input_tokens == 18
assert finish_events[0].output_tokens == 9
assert mock_acompletion.call_count == 2
first_call = mock_acompletion.call_args_list[0].kwargs
assert first_call["stream"] is True
assert "tools" in first_call
second_call = mock_acompletion.call_args_list[1].kwargs
assert "tools" not in second_call
assert "Tool compatibility mode is active" in second_call["messages"][0]["content"]
assert provider.model in OPENROUTER_TOOL_COMPAT_MODEL_CACHE
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_tool_compat_parses_textual_tool_calls_and_uses_cache(
self,
mock_acompletion,
):
"""Textual tool-call markers should become ToolCallEvents and skip repeat probing."""
from framework.llm.stream_events import ToolCallEvent
provider = LiteLLMProvider(
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
api_key="test-key",
)
tools = [
Tool(
name="ask_user_multiple",
description="Ask the user a multiple-choice question",
parameters={
"properties": {
"options": {"type": "array"},
"question": {"type": "string"},
"prompt": {"type": "string"},
},
"required": ["options", "question", "prompt"],
},
)
]
compat_response = MagicMock()
compat_response.choices = [MagicMock()]
compat_response.choices[0].message.content = (
"<|tool_call_start|>"
"[ask_user_multiple(options=['Quartet Collaborator', 'Project Advisor'], "
"question='Who are you?', prompt='Who are you?')]"
"<|tool_call_end|>"
)
compat_response.choices[0].finish_reason = "stop"
compat_response.model = provider.model
compat_response.usage.prompt_tokens = 10
compat_response.usage.completion_tokens = 5
call_state = {"count": 0}
async def side_effect(*args, **kwargs):
call_state["count"] += 1
if kwargs.get("stream"):
raise RuntimeError(
'OpenrouterException - {"error":{"message":"No endpoints found '
'that support tool use.","code":404}}'
)
return compat_response
mock_acompletion.side_effect = side_effect
first_events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "Who are you?"}],
system="Use tools when needed.",
tools=tools,
max_tokens=128,
):
first_events.append(event)
tool_calls = [event for event in first_events if isinstance(event, ToolCallEvent)]
assert len(tool_calls) == 1
assert tool_calls[0].tool_name == "ask_user_multiple"
assert tool_calls[0].tool_input == {
"options": ["Quartet Collaborator", "Project Advisor"],
"question": "Who are you?",
"prompt": "Who are you?",
}
second_events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "Who are you?"}],
system="Use tools when needed.",
tools=tools,
max_tokens=128,
):
second_events.append(event)
second_tool_calls = [event for event in second_events if isinstance(event, ToolCallEvent)]
assert len(second_tool_calls) == 1
assert mock_acompletion.call_count == 3
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
assert "stream" not in mock_acompletion.call_args_list[2].kwargs
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_tool_compat_parses_plain_text_tool_call_lines(
self,
mock_acompletion,
):
"""Plain textual tool-call lines should execute as tools, not user-visible text."""
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
provider = LiteLLMProvider(
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
api_key="test-key",
)
tools = [
Tool(
name="ask_user",
description="Ask the user a single multiple-choice question",
parameters={
"properties": {
"question": {"type": "string"},
"options": {"type": "array"},
},
"required": ["question", "options"],
},
)
]
compat_response = MagicMock()
compat_response.choices = [MagicMock()]
compat_response.choices[0].message.content = (
"Queen has been loaded. It's ready to assist with your planning needs.\n\n"
"ask_user('What would you like to do?', ['Define a new agent', "
"'Diagnose an existing agent', 'Explore tools'])"
)
compat_response.choices[0].finish_reason = "stop"
compat_response.model = provider.model
compat_response.usage.prompt_tokens = 11
compat_response.usage.completion_tokens = 7
async def side_effect(*args, **kwargs):
if kwargs.get("stream"):
raise RuntimeError(
'OpenrouterException - {"error":{"message":"No endpoints found '
'that support tool use.","code":404}}'
)
return compat_response
mock_acompletion.side_effect = side_effect
events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "hello"}],
system="Use tools when needed.",
tools=tools,
max_tokens=128,
):
events.append(event)
tool_calls = [event for event in events if isinstance(event, ToolCallEvent)]
assert len(tool_calls) == 1
assert tool_calls[0].tool_name == "ask_user"
assert tool_calls[0].tool_input == {
"question": "What would you like to do?",
"options": ["Define a new agent", "Diagnose an existing agent", "Explore tools"],
}
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
assert len(text_events) == 1
assert "ask_user(" not in text_events[0].snapshot
assert text_events[0].snapshot == (
"Queen has been loaded. It's ready to assist with your planning needs."
)
finish_events = [event for event in events if isinstance(event, FinishEvent)]
assert len(finish_events) == 1
assert finish_events[0].stop_reason == "tool_calls"
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_tool_compat_treats_non_json_as_plain_text(self, mock_acompletion):
"""If fallback output is not valid JSON, preserve it as assistant text."""
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
provider = LiteLLMProvider(
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
api_key="test-key",
)
tools = [
Tool(
name="web_search",
description="Search the web",
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
)
]
compat_response = MagicMock()
compat_response.choices = [MagicMock()]
compat_response.choices[0].message.content = "I can answer directly without tools."
compat_response.choices[0].finish_reason = "stop"
compat_response.model = provider.model
compat_response.usage.prompt_tokens = 12
compat_response.usage.completion_tokens = 6
async def side_effect(*args, **kwargs):
if kwargs.get("stream"):
raise RuntimeError(
'OpenrouterException - {"error":{"message":"No endpoints found '
'that support tool use.","code":404}}'
)
return compat_response
mock_acompletion.side_effect = side_effect
events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "Say hello."}],
system="Be concise.",
tools=tools,
max_tokens=128,
):
events.append(event)
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
assert len(text_events) == 1
assert text_events[0].snapshot == "I can answer directly without tools."
assert not any(isinstance(event, ToolCallEvent) for event in events)
finish_events = [event for event in events if isinstance(event, FinishEvent)]
assert len(finish_events) == 1
assert finish_events[0].stop_reason == "stop"
# ---------------------------------------------------------------------------
# AgentRunner._is_local_model — parameterized tests
# ---------------------------------------------------------------------------
+228
@@ -0,0 +1,228 @@
"""Unit tests for MCP client transport and reconnect behavior."""
from types import SimpleNamespace
import httpx
import pytest
from framework.runner import mcp_client as mcp_client_module
from framework.runner.mcp_client import MCPClient, MCPServerConfig, MCPTool
class _FakeResponse:
def __init__(self, payload=None):
self._payload = payload or {}
def raise_for_status(self) -> None:
"""Pretend the request succeeded."""
def json(self):
return self._payload
class _FakeHttpClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.get_calls: list[str] = []
self.closed = False
def get(self, path: str) -> _FakeResponse:
self.get_calls.append(path)
return _FakeResponse()
def close(self) -> None:
self.closed = True
def test_connect_unix_transport_uses_socket_path(monkeypatch):
created = {}
class FakeHTTPTransport:
def __init__(self, *, uds: str):
created["uds"] = uds
self.uds = uds
def fake_client_factory(**kwargs):
client = _FakeHttpClient(**kwargs)
created["client"] = client
return client
monkeypatch.setattr(mcp_client_module.httpx, "HTTPTransport", FakeHTTPTransport)
monkeypatch.setattr(mcp_client_module.httpx, "Client", fake_client_factory)
monkeypatch.setattr(MCPClient, "_discover_tools", lambda self: None)
client = MCPClient(
MCPServerConfig(
name="unix-server",
transport="unix",
url="http://localhost",
socket_path="/tmp/test.sock",
)
)
client.connect()
assert created["uds"] == "/tmp/test.sock"
assert client._http_client is created["client"] # noqa: SLF001 - direct unit test
assert created["client"].kwargs["base_url"] == "http://localhost"
assert created["client"].get_calls == ["/health"]
client.disconnect()
assert created["client"].closed is True
def test_connect_sse_and_list_tools(monkeypatch):
pytest.importorskip("mcp")
sse_module = pytest.importorskip("mcp.client.sse")
import mcp
contexts = []
class FakeSSEContext:
def __init__(self, url: str, headers: dict[str, str] | None, timeout: float):
self.url = url
self.headers = headers
self.timeout = timeout
self.exited = False
async def __aenter__(self):
return "read-stream", "write-stream"
async def __aexit__(self, exc_type, exc, tb):
self.exited = True
class FakeSession:
def __init__(self, read_stream, write_stream):
self.read_stream = read_stream
self.write_stream = write_stream
self.closed = False
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
self.closed = True
async def initialize(self):
"""Pretend session initialization succeeded."""
async def list_tools(self):
return SimpleNamespace(
tools=[
SimpleNamespace(
name="search",
description="Search docs",
inputSchema={"type": "object"},
)
]
)
def fake_sse_client(url: str, headers=None, timeout=5, **_kwargs):
context = FakeSSEContext(url=url, headers=headers, timeout=timeout)
contexts.append(context)
return context
monkeypatch.setattr(sse_module, "sse_client", fake_sse_client)
monkeypatch.setattr(mcp, "ClientSession", FakeSession)
client = MCPClient(
MCPServerConfig(
name="sse-server",
transport="sse",
url="http://localhost/sse",
headers={"Authorization": "Bearer token"},
)
)
client.connect()
tools = client.list_tools()
assert [tool.name for tool in tools] == ["search"]
assert tools[0].description == "Search docs"
assert contexts[0].url == "http://localhost/sse"
assert contexts[0].headers == {"Authorization": "Bearer token"}
assert contexts[0].timeout == 30.0
client.disconnect()
assert contexts[0].exited is True
def test_call_tool_retries_once_on_connect_error_for_unix(monkeypatch):
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
}
first_error = httpx.ConnectError("first failure")
calls = {"count": 0}
reconnects = []
def fake_call_tool_http(tool_name, arguments):
calls["count"] += 1
if calls["count"] == 1:
raise first_error
return [{"type": "text", "text": f"{tool_name}:{arguments['value']}"}]
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
result = client.call_tool("ping", {"value": "ok"})
assert result == [{"type": "text", "text": "ping:ok"}]
assert calls["count"] == 2
assert reconnects == ["reconnected"]
def test_call_tool_retry_exhausted_raises_original_error_for_unix(monkeypatch):
client = MCPClient(MCPServerConfig(name="unix-server", transport="unix"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "unix-server")
}
first_error = httpx.ConnectError("first failure")
second_error = httpx.ConnectError("second failure")
calls = {"count": 0}
reconnects = []
def fake_call_tool_http(_tool_name, _arguments):
calls["count"] += 1
if calls["count"] == 1:
raise first_error
raise second_error
monkeypatch.setattr(client, "_call_tool_http", fake_call_tool_http)
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
with pytest.raises(httpx.ConnectError) as exc_info:
client.call_tool("ping", {"value": "ok"})
assert exc_info.value is first_error
assert calls["count"] == 2
assert reconnects == ["reconnected"]
def test_call_tool_http_preserves_runtime_error_wrapping(monkeypatch):
client = MCPClient(MCPServerConfig(name="http-server", transport="http"))
client._connected = True # noqa: SLF001 - direct unit test
client._tools = { # noqa: SLF001 - direct unit test
"ping": MCPTool("ping", "Ping tool", {}, "http-server")
}
connect_error = httpx.ConnectError("first failure")
class FailingHttpClient:
def post(self, _path, json):
raise connect_error
client._http_client = FailingHttpClient() # noqa: SLF001 - direct unit test
reconnects = []
monkeypatch.setattr(client, "_reconnect", lambda: reconnects.append("reconnected"))
with pytest.raises(RuntimeError) as exc_info:
client.call_tool("ping", {"value": "ok"})
assert "Failed to call tool via HTTP" in str(exc_info.value)
assert exc_info.value.__cause__ is connect_error
assert reconnects == []
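The two unix-transport tests above pin down a precise retry shape: exactly one reconnect-and-retry on `httpx.ConnectError`, with the original error re-raised if the retry also fails, while the plain HTTP transport wraps failures in `RuntimeError` without retrying. A minimal sketch of that method body, assuming the private helpers named in the tests (`_call_tool_http`, `_reconnect`):

```python
# Sketch only — not the actual implementation. `self` is an MCPClient whose
# transport is "unix"; the helper names come straight from the monkeypatches.
import httpx

def call_tool(self, tool_name: str, arguments: dict) -> list[dict]:
    try:
        return self._call_tool_http(tool_name, arguments)
    except httpx.ConnectError as first_error:
        self._reconnect()  # drop and rebuild the unix-socket client
        try:
            return self._call_tool_http(tool_name, arguments)
        except httpx.ConnectError:
            # Retry exhausted: surface the FIRST failure, as the tests
            # assert via `exc_info.value is first_error`.
            raise first_error
```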
+172
@@ -0,0 +1,172 @@
"""Tests for the shared MCP connection manager."""
import threading
import httpx
import pytest
from framework.runner.mcp_client import MCPServerConfig, MCPTool
from framework.runner.mcp_connection_manager import MCPConnectionManager
class FakeMCPClient:
"""Minimal fake MCP client for connection manager tests."""
instances: list["FakeMCPClient"] = []
def __init__(self, config: MCPServerConfig):
self.config = config
self._connected = False
self.connect_calls = 0
self.disconnect_calls = 0
self.list_tools_calls = 0
self.list_tools_error: Exception | None = None
FakeMCPClient.instances.append(self)
def connect(self) -> None:
self.connect_calls += 1
self._connected = True
def disconnect(self) -> None:
self.disconnect_calls += 1
self._connected = False
def list_tools(self) -> list[MCPTool]:
self.list_tools_calls += 1
if self.list_tools_error is not None:
raise self.list_tools_error
return [MCPTool("ping", "Ping", {"type": "object"}, self.config.name)]
@pytest.fixture
def manager(monkeypatch):
monkeypatch.setattr("framework.runner.mcp_connection_manager.MCPClient", FakeMCPClient)
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FakeMCPClient.instances.clear()
manager = MCPConnectionManager.get_instance()
yield manager
manager.cleanup_all()
monkeypatch.setattr(MCPConnectionManager, "_instance", None)
FakeMCPClient.instances.clear()
def test_acquire_returns_same_client_for_same_server_name(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client_one = manager.acquire(config)
client_two = manager.acquire(config)
assert client_one is client_two
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
assert len(FakeMCPClient.instances) == 1
def test_release_with_refcount_above_one_keeps_connection_open(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
manager.acquire(config)
manager.release("shared")
assert client.disconnect_calls == 0
assert manager._pool["shared"] is client # noqa: SLF001 - state assertion for unit test
assert manager._refcounts["shared"] == 1 # noqa: SLF001 - state assertion for unit test
def test_release_last_reference_disconnects_and_removes_from_pool(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
manager.release("shared")
assert client.disconnect_calls == 1
assert "shared" not in manager._pool # noqa: SLF001 - state assertion for unit test
assert "shared" not in manager._refcounts # noqa: SLF001 - state assertion for unit test
def test_concurrent_acquire_and_release_keeps_state_consistent(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
worker_count = 8
acquire_barrier = threading.Barrier(worker_count + 1)
release_barrier = threading.Barrier(worker_count)
acquired_clients: list[FakeMCPClient] = []
acquired_lock = threading.Lock()
def worker() -> None:
acquire_barrier.wait()
client = manager.acquire(config)
with acquired_lock:
acquired_clients.append(client)
release_barrier.wait()
manager.release("shared")
threads = [threading.Thread(target=worker) for _ in range(worker_count)]
for thread in threads:
thread.start()
acquire_barrier.wait()
for thread in threads:
thread.join()
assert len({id(client) for client in acquired_clients}) == 1
assert len(FakeMCPClient.instances) == 1
assert FakeMCPClient.instances[0].disconnect_calls == 1
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
def test_cleanup_all_disconnects_every_pooled_client(manager):
manager.acquire(MCPServerConfig(name="one", transport="stdio", command="echo"))
manager.acquire(MCPServerConfig(name="two", transport="stdio", command="echo"))
manager.cleanup_all()
assert len(FakeMCPClient.instances) == 2
assert all(client.disconnect_calls == 1 for client in FakeMCPClient.instances)
assert manager._pool == {} # noqa: SLF001 - state assertion for unit test
assert manager._refcounts == {} # noqa: SLF001 - state assertion for unit test
assert manager._configs == {} # noqa: SLF001 - state assertion for unit test
def test_reconnect_replaces_client_even_with_existing_refcount(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
original_client = manager.acquire(config)
manager.acquire(config)
replacement = manager.reconnect("shared")
assert replacement is not original_client
assert original_client.disconnect_calls == 1
assert manager._pool["shared"] is replacement # noqa: SLF001 - state assertion for unit test
assert manager._refcounts["shared"] == 2 # noqa: SLF001 - state assertion for unit test
def test_health_check_returns_false_when_server_is_unreachable(manager, monkeypatch):
config = MCPServerConfig(name="shared", transport="http", url="http://localhost:9")
manager.acquire(config)
class FailingHttpClient:
def __init__(self, **_kwargs):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def get(self, _path: str):
raise httpx.ConnectError("unreachable")
monkeypatch.setattr("framework.runner.mcp_connection_manager.httpx.Client", FailingHttpClient)
assert manager.health_check("shared") is False
def test_health_check_for_stdio_returns_false_on_tools_list_error(manager):
config = MCPServerConfig(name="shared", transport="stdio", command="echo")
client = manager.acquire(config)
client.list_tools_error = RuntimeError("broken")
assert manager.health_check("shared") is False
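These tests describe a refcounted singleton pool: one client per server name, disconnect only on the last release, and a single lock covering both paths so the concurrent test stays consistent. A minimal sketch under those assumptions (the attribute names `_pool`, `_refcounts`, and `_configs` come from the assertions; `MCPClient` is the class the fixture swaps for `FakeMCPClient`):

```python
# Sketch only — illustrates the refcounting contract the tests assert.
import threading

class MCPConnectionManager:
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._pool: dict[str, object] = {}
        self._refcounts: dict[str, int] = {}
        self._configs: dict[str, object] = {}

    @classmethod
    def get_instance(cls) -> "MCPConnectionManager":
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def acquire(self, config):
        with self._lock:
            if config.name not in self._pool:
                client = MCPClient(config)  # noqa: F821 — assumed import
                client.connect()
                self._pool[config.name] = client
                self._configs[config.name] = config
                self._refcounts[config.name] = 0
            self._refcounts[config.name] += 1
            return self._pool[config.name]

    def release(self, server_name: str) -> None:
        with self._lock:
            if server_name not in self._refcounts:
                return
            self._refcounts[server_name] -= 1
            if self._refcounts[server_name] <= 0:
                self._pool.pop(server_name).disconnect()
                self._refcounts.pop(server_name)
                self._configs.pop(server_name)
```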
@@ -21,3 +21,8 @@ def test_minimax_provider_prefix_maps_to_minimax_api_key():
def test_minimax_model_name_prefix_maps_to_minimax_api_key():
runner = _runner_for_unit_test()
assert runner._get_api_key_env_var("minimax-chat") == "MINIMAX_API_KEY"
def test_openrouter_provider_prefix_maps_to_openrouter_api_key():
runner = _runner_for_unit_test()
assert runner._get_api_key_env_var("openrouter/x-ai/grok-4.20-beta") == "OPENROUTER_API_KEY"
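The three assertions in this hunk imply a straightforward provider-prefix map. A sketch, assuming the mapping is a plain dict and the provider is whatever precedes the first `/` in the model string:

```python
# Sketch only — just the mappings the tests assert; the real table is larger.
_PREFIX_TO_ENV = {
    "minimax": "MINIMAX_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}

def get_api_key_env_var(model: str) -> str | None:
    provider = model.split("/", 1)[0].lower()
    for prefix, env_var in _PREFIX_TO_ENV.items():
        if provider.startswith(prefix):  # matches "minimax-chat" too
            return env_var
    return None
```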
+520
@@ -0,0 +1,520 @@
"""Tests for safe_eval — the sandboxed expression evaluator used by edge conditions.
Covers: literals, data structures, arithmetic, comparisons, boolean logic
(including short-circuit semantics), variable lookup, subscript/attribute
access, whitelisted function calls, method calls, ternary expressions,
chained comparisons, and security boundaries (private attrs, disallowed
AST nodes, disallowed function calls).
"""
import pytest
from framework.graph.safe_eval import safe_eval
# ---------------------------------------------------------------------------
# Literals and constants
# ---------------------------------------------------------------------------
class TestLiterals:
def test_integer(self):
assert safe_eval("42") == 42
def test_negative_integer(self):
assert safe_eval("-1") == -1
def test_float(self):
assert safe_eval("3.14") == pytest.approx(3.14)
def test_string(self):
assert safe_eval("'hello'") == "hello"
def test_double_quoted_string(self):
assert safe_eval('"world"') == "world"
def test_boolean_true(self):
assert safe_eval("True") is True
def test_boolean_false(self):
assert safe_eval("False") is False
def test_none(self):
assert safe_eval("None") is None
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
class TestDataStructures:
def test_list(self):
assert safe_eval("[1, 2, 3]") == [1, 2, 3]
def test_empty_list(self):
assert safe_eval("[]") == []
def test_nested_list(self):
assert safe_eval("[[1, 2], [3, 4]]") == [[1, 2], [3, 4]]
def test_tuple(self):
assert safe_eval("(1, 2, 3)") == (1, 2, 3)
def test_dict(self):
assert safe_eval("{'a': 1, 'b': 2}") == {"a": 1, "b": 2}
def test_empty_dict(self):
assert safe_eval("{}") == {}
# ---------------------------------------------------------------------------
# Arithmetic and binary operators
# ---------------------------------------------------------------------------
class TestArithmetic:
def test_addition(self):
assert safe_eval("2 + 3") == 5
def test_subtraction(self):
assert safe_eval("10 - 4") == 6
def test_multiplication(self):
assert safe_eval("3 * 7") == 21
def test_division(self):
assert safe_eval("10 / 4") == 2.5
def test_floor_division(self):
assert safe_eval("10 // 3") == 3
def test_modulo(self):
assert safe_eval("10 % 3") == 1
def test_power(self):
assert safe_eval("2 ** 10") == 1024
def test_complex_expression(self):
assert safe_eval("(2 + 3) * 4 - 1") == 19
# ---------------------------------------------------------------------------
# Unary operators
# ---------------------------------------------------------------------------
class TestUnaryOps:
def test_negation(self):
assert safe_eval("-5") == -5
def test_positive(self):
assert safe_eval("+5") == 5
def test_not_true(self):
assert safe_eval("not True") is False
def test_not_false(self):
assert safe_eval("not False") is True
def test_bitwise_invert(self):
assert safe_eval("~0") == -1
# ---------------------------------------------------------------------------
# Comparisons
# ---------------------------------------------------------------------------
class TestComparisons:
def test_equal(self):
assert safe_eval("1 == 1") is True
def test_not_equal(self):
assert safe_eval("1 != 2") is True
def test_less_than(self):
assert safe_eval("1 < 2") is True
def test_greater_than(self):
assert safe_eval("2 > 1") is True
def test_less_equal(self):
assert safe_eval("2 <= 2") is True
def test_greater_equal(self):
assert safe_eval("3 >= 2") is True
def test_is_none(self):
assert safe_eval("x is None", {"x": None}) is True
def test_is_not_none(self):
assert safe_eval("x is not None", {"x": 42}) is True
def test_in_list(self):
assert safe_eval("'a' in x", {"x": ["a", "b", "c"]}) is True
def test_not_in_list(self):
assert safe_eval("'z' not in x", {"x": ["a", "b"]}) is True
def test_chained_comparison(self):
"""Chained comparisons like 1 < x < 10 should work."""
assert safe_eval("1 < x < 10", {"x": 5}) is True
def test_chained_comparison_false(self):
assert safe_eval("1 < x < 3", {"x": 5}) is False
def test_chained_three_way(self):
assert safe_eval("0 <= x <= 100", {"x": 50}) is True
# ---------------------------------------------------------------------------
# Boolean operators (with short-circuit semantics)
# ---------------------------------------------------------------------------
class TestBooleanOps:
def test_and_true(self):
assert safe_eval("True and True") is True
def test_and_false(self):
assert safe_eval("True and False") is False
def test_or_true(self):
assert safe_eval("False or True") is True
def test_or_false(self):
assert safe_eval("False or False") is False
def test_and_returns_last_truthy(self):
"""Python `and` returns the last value if all truthy."""
assert safe_eval("1 and 2 and 3") == 3
def test_and_returns_first_falsy(self):
"""Python `and` returns the first falsy value."""
assert safe_eval("1 and 0 and 3") == 0
def test_or_returns_first_truthy(self):
"""Python `or` returns the first truthy value."""
assert safe_eval("0 or '' or 42") == 42
def test_or_returns_last_falsy(self):
"""Python `or` returns the last value if all falsy."""
assert safe_eval("0 or '' or None") is None
def test_and_short_circuits(self):
"""and should NOT evaluate the right side if left is falsy.
This is the bug we fixed: previously this would crash with
TypeError because all operands were eagerly evaluated.
"""
# x is None, so `x.get("key")` would crash if evaluated
assert safe_eval("x is not None and x.get('key')", {"x": None}) is False
def test_or_short_circuits(self):
"""or should NOT evaluate the right side if left is truthy."""
# x is truthy, so the crash-prone right side should never run
assert safe_eval("x or y.get('missing')", {"x": "found", "y": {}}) == "found"
def test_and_guard_pattern_truthy(self):
"""Guard pattern: check not None, then access — when value exists."""
ctx = {"x": {"key": "value"}}
assert safe_eval("x is not None and x.get('key')", ctx) == "value"
def test_multi_and(self):
assert safe_eval("True and True and True") is True
def test_multi_or(self):
assert safe_eval("False or False or True") is True
def test_mixed_and_or(self):
assert safe_eval("True or False and False") is True
# ---------------------------------------------------------------------------
# Ternary (if/else) expressions
# ---------------------------------------------------------------------------
class TestTernary:
def test_ternary_true_branch(self):
assert safe_eval("'yes' if True else 'no'") == "yes"
def test_ternary_false_branch(self):
assert safe_eval("'yes' if False else 'no'") == "no"
def test_ternary_with_context(self):
assert safe_eval("x * 2 if x > 0 else -x", {"x": 5}) == 10
def test_ternary_false_with_context(self):
assert safe_eval("x * 2 if x > 0 else -x", {"x": -3}) == 3
# ---------------------------------------------------------------------------
# Variable lookup
# ---------------------------------------------------------------------------
class TestVariables:
def test_simple_variable(self):
assert safe_eval("x", {"x": 42}) == 42
def test_string_variable(self):
assert safe_eval("name", {"name": "Alice"}) == "Alice"
def test_dict_variable(self):
ctx = {"output": {"status": "ok"}}
assert safe_eval("output", ctx) == {"status": "ok"}
def test_undefined_variable_raises(self):
with pytest.raises(NameError, match="not defined"):
safe_eval("undefined_var")
def test_multiple_variables(self):
assert safe_eval("x + y", {"x": 10, "y": 20}) == 30
# ---------------------------------------------------------------------------
# Subscript access (indexing)
# ---------------------------------------------------------------------------
class TestSubscript:
def test_dict_subscript(self):
assert safe_eval("d['key']", {"d": {"key": "value"}}) == "value"
def test_list_subscript(self):
assert safe_eval("items[0]", {"items": [10, 20, 30]}) == 10
def test_nested_subscript(self):
ctx = {"data": {"users": [{"name": "Alice"}]}}
assert safe_eval("data['users'][0]['name']", ctx) == "Alice"
def test_missing_key_raises(self):
with pytest.raises(KeyError):
safe_eval("d['missing']", {"d": {}})
# ---------------------------------------------------------------------------
# Attribute access
# ---------------------------------------------------------------------------
class TestAttributeAccess:
def test_private_attr_blocked(self):
"""Attributes starting with _ must be blocked for security."""
with pytest.raises(ValueError, match="private attribute"):
safe_eval("x.__class__", {"x": 42})
def test_dunder_blocked(self):
with pytest.raises(ValueError, match="private attribute"):
safe_eval("x.__dict__", {"x": {}})
def test_single_underscore_blocked(self):
with pytest.raises(ValueError, match="private attribute"):
safe_eval("x._internal", {"x": {}})
# ---------------------------------------------------------------------------
# Whitelisted function calls
# ---------------------------------------------------------------------------
class TestFunctionCalls:
def test_len(self):
assert safe_eval("len(x)", {"x": [1, 2, 3]}) == 3
def test_int_conversion(self):
assert safe_eval("int('42')") == 42
def test_float_conversion(self):
assert safe_eval("float('3.14')") == pytest.approx(3.14)
def test_str_conversion(self):
assert safe_eval("str(42)") == "42"
def test_bool_conversion(self):
assert safe_eval("bool(1)") is True
def test_abs(self):
assert safe_eval("abs(-5)") == 5
def test_min(self):
assert safe_eval("min(3, 1, 2)") == 1
def test_max(self):
assert safe_eval("max(3, 1, 2)") == 3
def test_sum(self):
assert safe_eval("sum(x)", {"x": [1, 2, 3]}) == 6
def test_round(self):
assert safe_eval("round(3.7)") == 4
def test_all(self):
assert safe_eval("all([True, True, True])") is True
def test_any(self):
assert safe_eval("any([False, False, True])") is True
def test_list_constructor(self):
assert safe_eval("list(x)", {"x": (1, 2, 3)}) == [1, 2, 3]
def test_dict_constructor(self):
assert safe_eval("dict(a=1, b=2)") == {"a": 1, "b": 2}
def test_tuple_constructor(self):
assert safe_eval("tuple(x)", {"x": [1, 2]}) == (1, 2)
def test_set_constructor(self):
assert safe_eval("set(x)", {"x": [1, 2, 2, 3]}) == {1, 2, 3}
# ---------------------------------------------------------------------------
# Whitelisted method calls
# ---------------------------------------------------------------------------
class TestMethodCalls:
def test_dict_get(self):
assert safe_eval("d.get('key', 'default')", {"d": {"key": "val"}}) == "val"
def test_dict_get_missing(self):
assert safe_eval("d.get('missing', 'default')", {"d": {}}) == "default"
def test_dict_keys(self):
result = safe_eval("list(d.keys())", {"d": {"a": 1, "b": 2}})
assert sorted(result) == ["a", "b"]
def test_dict_values(self):
result = safe_eval("list(d.values())", {"d": {"a": 1, "b": 2}})
assert sorted(result) == [1, 2]
def test_string_lower(self):
assert safe_eval("s.lower()", {"s": "HELLO"}) == "hello"
def test_string_upper(self):
assert safe_eval("s.upper()", {"s": "hello"}) == "HELLO"
def test_string_strip(self):
assert safe_eval("s.strip()", {"s": " hi "}) == "hi"
def test_string_split(self):
assert safe_eval("s.split(',')", {"s": "a,b,c"}) == ["a", "b", "c"]
# ---------------------------------------------------------------------------
# Security: disallowed operations
# ---------------------------------------------------------------------------
class TestSecurity:
def test_import_blocked(self):
"""__import__ is not in context, so NameError is raised."""
with pytest.raises(NameError, match="not defined"):
safe_eval("__import__('os')")
def test_lambda_blocked(self):
with pytest.raises(ValueError, match="not allowed"):
safe_eval("(lambda: 1)()")
def test_comprehension_blocked(self):
with pytest.raises(ValueError, match="not allowed"):
safe_eval("[x for x in range(10)]")
def test_assignment_blocked(self):
"""Assignment expressions should not parse in eval mode."""
with pytest.raises(SyntaxError):
safe_eval("x = 5")
def test_disallowed_function_blocked(self):
"""eval is not in safe functions, so NameError is raised."""
with pytest.raises(NameError, match="not defined"):
safe_eval("eval('1+1')")
def test_exec_blocked(self):
"""exec is not in safe functions, so NameError is raised."""
with pytest.raises(NameError, match="not defined"):
safe_eval("exec('x=1')")
def test_type_call_blocked(self):
"""type is not in safe functions, so NameError is raised."""
with pytest.raises(NameError, match="not defined"):
safe_eval("type(42)")
def test_getattr_builtin_blocked(self):
"""getattr is not in safe functions, so NameError is raised."""
with pytest.raises(NameError, match="not defined"):
safe_eval("getattr(x, '__class__')", {"x": 42})
def test_empty_expression_raises(self):
with pytest.raises(SyntaxError):
safe_eval("")
# ---------------------------------------------------------------------------
# Real-world edge condition patterns (from graph executor usage)
# ---------------------------------------------------------------------------
class TestEdgeConditionPatterns:
"""Patterns commonly used in EdgeSpec.condition_expr."""
def test_output_key_exists_and_not_none(self):
ctx = {"output": {"approved_contacts": ["alice@example.com"]}}
assert safe_eval("output.get('approved_contacts') is not None", ctx) is True
def test_output_key_missing(self):
ctx = {"output": {}}
assert safe_eval("output.get('approved_contacts') is not None", ctx) is False
def test_output_key_check_with_fallback(self):
ctx = {"output": {"redo_extraction": True}}
assert safe_eval("output.get('redo_extraction') is not None", ctx) is True
def test_guard_then_length_check(self):
"""Guard pattern: check key exists, then check length."""
ctx = {"output": {"results": [1, 2, 3]}}
assert (
safe_eval(
"output.get('results') is not None and len(output['results']) > 0",
ctx,
)
is True
)
def test_guard_short_circuits_on_none(self):
"""Guard pattern: short-circuit prevents crash on None."""
ctx = {"output": {}}
assert (
safe_eval(
"output.get('results') is not None and len(output['results']) > 0",
ctx,
)
is False
)
def test_success_flag_check(self):
ctx = {"output": {"success": True}, "memory": {"attempts": 2}}
assert safe_eval("output.get('success') == True", ctx) is True
def test_memory_threshold(self):
ctx = {"memory": {"score": 0.85}}
assert safe_eval("memory.get('score', 0) >= 0.8", ctx) is True
def test_string_contains_check(self):
ctx = {"output": {"status": "completed_with_warnings"}}
assert safe_eval("'completed' in output.get('status', '')", ctx) is True
def test_fallback_chain(self):
"""or-chain for fallback values."""
ctx = {"output": {}}
result = safe_eval(
"output.get('primary') or output.get('secondary') or 'default'",
ctx,
)
assert result == "default"
def test_no_context_needed(self):
"""Some edges use constant expressions."""
assert safe_eval("True") is True
assert safe_eval("1 == 1") is True
+142
@@ -0,0 +1,142 @@
"""Tests for AS-9: Skill directory allowlisting in file-read tool interception."""
from unittest.mock import MagicMock
import pytest
from framework.llm.provider import ToolResult
def _make_tool_call_event(tool_name: str, path: str):
"""Build a minimal ToolCallEvent-like object."""
tc = MagicMock()
tc.tool_use_id = "tc-1"
tc.tool_name = tool_name
tc.tool_input = {"path": path}
return tc
def _make_node(skill_dirs: list[str]):
"""Build a minimal EventLoopNode with skill_dirs set."""
from framework.graph.event_loop_node import EventLoopNode
mock_result = ToolResult(tool_use_id="tc-1", content="from-executor")
node = EventLoopNode(tool_executor=MagicMock(return_value=mock_result))
node._skill_dirs = skill_dirs
return node
class TestSkillFileReadInterception:
@pytest.mark.asyncio
async def test_reads_file_in_skill_dir(self, tmp_path):
"""File under a skill dir is read directly, bypassing the executor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
script = skill_dir / "scripts" / "run.py"
script.parent.mkdir()
script.write_text("print('hello')")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(script))
result = await node._execute_tool(tc)
assert result.content == "print('hello')"
assert not result.is_error
node._tool_executor.assert_not_called()
@pytest.mark.asyncio
async def test_skill_md_read_marked_as_skill_content(self, tmp_path):
"""Reading SKILL.md sets is_skill_content=True for AS-10 protection."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
skill_md = skill_dir / "SKILL.md"
skill_md.write_text("---\nname: my-skill\ndescription: Test\n---\nInstructions.")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(skill_md))
result = await node._execute_tool(tc)
assert result.is_skill_content is True
assert not result.is_error
@pytest.mark.asyncio
async def test_non_skill_md_resource_not_marked(self, tmp_path):
"""Bundled resource (not SKILL.md) is NOT marked as skill_content."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
ref = skill_dir / "references" / "api.md"
ref.parent.mkdir()
ref.write_text("# API Reference")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("load_data", str(ref))
result = await node._execute_tool(tc)
assert result.is_skill_content is False
assert not result.is_error
@pytest.mark.asyncio
async def test_path_outside_skill_dir_goes_to_executor(self, tmp_path):
"""Path outside skill dirs is passed through to the executor unchanged."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
other_file = tmp_path / "other" / "file.txt"
other_file.parent.mkdir()
other_file.write_text("other content")
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(other_file))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
@pytest.mark.asyncio
async def test_no_skill_dirs_goes_to_executor(self, tmp_path):
"""When skill_dirs is empty, all tool calls go to executor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
script = skill_dir / "scripts" / "run.py"
script.parent.mkdir()
script.write_text("print('hello')")
node = _make_node([])
tc = _make_tool_call_event("view_file", str(script))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
@pytest.mark.asyncio
async def test_missing_file_returns_error(self, tmp_path):
"""Non-existent file under skill dir returns is_error=True."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
missing = skill_dir / "scripts" / "missing.py"
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("view_file", str(missing))
result = await node._execute_tool(tc)
assert result.is_error is True
assert "Could not read skill resource" in result.content
@pytest.mark.asyncio
async def test_non_file_read_tool_goes_to_executor(self, tmp_path):
"""Non file-read tools (e.g. web_search) bypass the interceptor."""
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
node = _make_node([str(skill_dir)])
tc = _make_tool_call_event("web_search", str(skill_dir / "SKILL.md"))
result = await node._execute_tool(tc)
assert result.content == "from-executor"
node._tool_executor.assert_called_once()
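Taken together, these tests outline the AS-9 interception: file-read tool calls whose path resolves inside an allowlisted skill directory are served directly, a SKILL.md read gets the AS-10 flag, and everything else falls through to the executor. A sketch under those assumptions (the intercepted tool set and error wording are taken from the tests; `ToolResult` field defaults are assumed):

```python
# Sketch only — returns None to mean "fall through to the normal executor".
from pathlib import Path

from framework.llm.provider import ToolResult

FILE_READ_TOOLS = {"view_file", "load_data"}  # assumed set, per the tests

def try_skill_read(tool_name: str, path: str, skill_dirs: list[str], tool_use_id: str):
    if tool_name not in FILE_READ_TOOLS or not skill_dirs:
        return None
    resolved = Path(path).resolve()
    for skill_dir in skill_dirs:
        base = Path(skill_dir).resolve()
        if base == resolved or base in resolved.parents:
            try:
                content = resolved.read_text(encoding="utf-8")
            except OSError as exc:
                return ToolResult(
                    tool_use_id=tool_use_id,
                    content=f"Could not read skill resource: {exc}",
                    is_error=True,
                )
            # Only SKILL.md itself gets the AS-10 pruning-protection flag.
            return ToolResult(
                tool_use_id=tool_use_id,
                content=content,
                is_skill_content=(resolved.name == "SKILL.md"),
            )
    return None
```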
+8 -1
@@ -69,7 +69,13 @@ class TestSkillCatalog:
def test_to_prompt_xml_generation(self):
skills = [
_make_skill("alpha", "Alpha skill", "project", location="/p/alpha/SKILL.md"),
_make_skill(
"alpha",
"Alpha skill",
"project",
location="/p/alpha/SKILL.md",
base_dir="/p/alpha",
),
_make_skill("beta", "Beta skill", "user", location="/u/beta/SKILL.md"),
]
catalog = SkillCatalog(skills)
@@ -81,6 +87,7 @@ class TestSkillCatalog:
assert "<name>beta</name>" in prompt
assert "<description>Alpha skill</description>" in prompt
assert "<location>/p/alpha/SKILL.md</location>" in prompt
assert "<base_dir>/p/alpha</base_dir>" in prompt
def test_to_prompt_sorted_by_name(self):
skills = [
@@ -0,0 +1,90 @@
"""Tests for AS-10: Activated skill content protected from context pruning."""
import pytest
from framework.graph.conversation import Message, NodeConversation
def _make_conversation() -> NodeConversation:
conv = NodeConversation.__new__(NodeConversation)
conv._messages = []
conv._next_seq = 0
conv._current_phase = None
conv._store = None
return conv
async def _add_tool_msg(conv: NodeConversation, content: str, **kwargs) -> Message:
return await conv.add_tool_result(
tool_use_id=f"tc-{conv._next_seq}",
content=content,
**kwargs,
)
class TestSkillContentProtection:
@pytest.mark.asyncio
async def test_is_skill_content_flag_persists(self):
"""Message created with is_skill_content=True retains the flag."""
conv = _make_conversation()
msg = await _add_tool_msg(conv, "skill instructions", is_skill_content=True)
assert msg.is_skill_content is True
@pytest.mark.asyncio
async def test_regular_message_not_marked(self):
"""Normal tool result messages are not marked as skill content."""
conv = _make_conversation()
msg = await _add_tool_msg(conv, "some tool output")
assert msg.is_skill_content is False
@pytest.mark.asyncio
async def test_skill_content_survives_prune(self):
"""Skill content messages are skipped by prune_old_tool_results."""
conv = _make_conversation()
# Add many regular tool results to push over prune threshold
for _ in range(30):
await _add_tool_msg(conv, "x" * 500) # ~125 tokens each
# Add a skill content message
skill_msg = await _add_tool_msg(
conv,
"## Deep Research\n" + "instructions " * 200,
is_skill_content=True,
)
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
assert pruned > 0, "Expected some messages to be pruned"
# Find the skill message — it must not be pruned
matching = [m for m in conv._messages if m.seq == skill_msg.seq]
assert matching, "Skill content message was removed"
assert not matching[0].content.startswith("[Pruned tool result")
@pytest.mark.asyncio
async def test_regular_content_can_be_pruned(self):
"""Regular tool results are still pruned when over threshold."""
conv = _make_conversation()
for _ in range(20):
await _add_tool_msg(conv, "regular tool output " * 50)
pruned = await conv.prune_old_tool_results(protect_tokens=500, min_prune_tokens=100)
assert pruned > 0, "Expected regular messages to be pruned"
@pytest.mark.asyncio
async def test_error_messages_also_protected(self):
"""Existing is_error protection still works alongside is_skill_content."""
conv = _make_conversation()
for _ in range(20):
await _add_tool_msg(conv, "output " * 100)
err_msg = await _add_tool_msg(conv, "tool failed", is_error=True)
await conv.prune_old_tool_results(protect_tokens=200, min_prune_tokens=50)
matching = [m for m in conv._messages if m.seq == err_msg.seq]
assert matching
assert not matching[0].content.startswith("[Pruned tool result")
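What these tests actually pin down is the protection predicate, not the budgeting math: messages flagged `is_error` or `is_skill_content` must never be pruned, and pruned content is replaced with a stub starting with `[Pruned tool result`. A heavily hedged sketch (the token accounting here is an assumption for illustration):

```python
# Sketch only — the grounded parts are the skip predicate and the stub prefix.
def prune_old_tool_results(messages, protect_tokens: int, min_prune_tokens: int,
                           count_tokens=lambda text: len(text) // 4) -> int:
    pruned = 0
    remaining = sum(count_tokens(m.content) for m in messages)
    for msg in messages:  # oldest first
        if remaining <= protect_tokens:
            break  # what's left fits inside the protected budget
        if getattr(msg, "is_error", False) or getattr(msg, "is_skill_content", False):
            continue  # protected content is never pruned
        tokens = count_tokens(msg.content)
        if tokens < min_prune_tokens:
            continue  # too small to be worth stubbing out
        msg.content = "[Pruned tool result]"
        remaining -= tokens
        pruned += 1
    return pruned
```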
+151
@@ -0,0 +1,151 @@
"""Tests for skill system structured error codes and diagnostics."""
from __future__ import annotations
import logging
from framework.skills.skill_errors import (
SkillError,
SkillErrorCode,
log_skill_error,
)
class TestSkillErrorCode:
def test_all_codes_defined(self):
codes = {e.value for e in SkillErrorCode}
assert "SKILL_NOT_FOUND" in codes
assert "SKILL_PARSE_ERROR" in codes
assert "SKILL_ACTIVATION_FAILED" in codes
assert "SKILL_MISSING_DESCRIPTION" in codes
assert "SKILL_YAML_FIXUP" in codes
assert "SKILL_NAME_MISMATCH" in codes
assert "SKILL_COLLISION" in codes
class TestSkillError:
def test_code_stored(self):
err = SkillError(
code=SkillErrorCode.SKILL_NOT_FOUND,
what="Skill 'my-skill' not found",
why="Not in catalog",
fix="Check discovery paths",
)
assert err.code == SkillErrorCode.SKILL_NOT_FOUND
def test_message_format(self):
err = SkillError(
code=SkillErrorCode.SKILL_MISSING_DESCRIPTION,
what="Missing description in '/path/SKILL.md'",
why="The description field is absent",
fix="Add a description field to the frontmatter",
)
expected = (
"[SKILL_MISSING_DESCRIPTION]\n"
"What failed: Missing description in '/path/SKILL.md'\n"
"Why: The description field is absent\n"
"Fix: Add a description field to the frontmatter"
)
assert str(err) == expected
def test_is_exception(self):
err = SkillError(
code=SkillErrorCode.SKILL_PARSE_ERROR,
what="Parse failed",
why="Invalid YAML",
fix="Fix the YAML",
)
assert isinstance(err, Exception)
def test_what_why_fix_attributes(self):
err = SkillError(
code=SkillErrorCode.SKILL_COLLISION,
what="Name collision",
why="Two skills share the same name",
fix="Rename one skill directory",
)
assert err.what == "Name collision"
assert err.why == "Two skills share the same name"
assert err.fix == "Rename one skill directory"
class TestLogSkillError:
def test_emits_log(self, caplog):
test_logger = logging.getLogger("test_skill")
with caplog.at_level(logging.ERROR, logger="test_skill"):
log_skill_error(
test_logger,
"error",
SkillErrorCode.SKILL_PARSE_ERROR,
what="Invalid SKILL.md at '/path'",
why="Empty file",
fix="Add content",
)
assert "SKILL_PARSE_ERROR" in caplog.text
def test_warning_level(self, caplog):
test_logger = logging.getLogger("test_skill_warn")
with caplog.at_level(logging.WARNING, logger="test_skill_warn"):
log_skill_error(
test_logger,
"warning",
SkillErrorCode.SKILL_YAML_FIXUP,
what="Auto-fixed YAML",
why="Unquoted colons",
fix="Quote values",
)
assert "SKILL_YAML_FIXUP" in caplog.text
def test_message_contains_all_parts(self, caplog):
test_logger = logging.getLogger("test_skill_parts")
with caplog.at_level(logging.ERROR, logger="test_skill_parts"):
log_skill_error(
test_logger,
"error",
SkillErrorCode.SKILL_NOT_FOUND,
what="Skill not found",
why="Not discovered",
fix="Check paths",
)
assert "Skill not found" in caplog.text
assert "Not discovered" in caplog.text
assert "Check paths" in caplog.text
class TestSkillErrorInParser:
def test_missing_description_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "no-desc"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.\n", encoding="utf-8")
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is None
def test_empty_file_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "empty"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("", encoding="utf-8")
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is None
def test_nonexistent_returns_none(self, tmp_path):
from framework.skills.parser import parse_skill_md
result = parse_skill_md(tmp_path / "ghost" / "SKILL.md")
assert result is None
def test_yaml_fixup_still_parses(self, tmp_path):
from framework.skills.parser import parse_skill_md
skill_dir = tmp_path / "colon-test"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\nname: colon-test\ndescription: Use for: research\n---\nBody.\n",
encoding="utf-8",
)
result = parse_skill_md(skill_dir / "SKILL.md")
assert result is not None
assert "research" in result.description
+92
@@ -0,0 +1,92 @@
"""Tests for AS-6 skill resource loading support.
Covers:
- <base_dir> element in catalog XML
- allowlisted_dirs property reflects trusted skill base directories
- skill_dirs propagation to NodeContext
"""
from framework.skills.catalog import SkillCatalog
from framework.skills.parser import ParsedSkill
def _make_skill(
name: str,
base_dir: str,
source_scope: str = "project",
) -> ParsedSkill:
return ParsedSkill(
name=name,
description=f"Skill {name}",
location=f"{base_dir}/SKILL.md",
base_dir=base_dir,
source_scope=source_scope,
body="Instructions.",
)
class TestSkillResourceBaseDir:
def test_base_dir_in_xml(self):
"""Each community skill entry should expose its base_dir in the catalog XML."""
skill = _make_skill("deploy", "/project/.hive/skills/deploy")
catalog = SkillCatalog([skill])
prompt = catalog.to_prompt()
assert "<base_dir>/project/.hive/skills/deploy</base_dir>" in prompt
def test_base_dir_xml_escaped(self):
"""base_dir with XML-special chars should be escaped."""
skill = _make_skill("s", "/path/with <&> chars")
catalog = SkillCatalog([skill])
prompt = catalog.to_prompt()
assert "<base_dir>/path/with &lt;&amp;&gt; chars</base_dir>" in prompt
def test_base_dir_absent_for_framework_skills(self):
"""Framework-scope skills are filtered from the catalog, so no base_dir either."""
skill = _make_skill("fw", "/hive/_default_skills/fw", source_scope="framework")
catalog = SkillCatalog([skill])
assert catalog.to_prompt() == ""
def test_allowlisted_dirs_matches_skills(self):
"""allowlisted_dirs returns all skill base_dirs including framework ones."""
skills = [
_make_skill("a", "/skills/a", "project"),
_make_skill("b", "/skills/b", "user"),
_make_skill("c", "/skills/c", "framework"),
]
catalog = SkillCatalog(skills)
dirs = catalog.allowlisted_dirs
assert "/skills/a" in dirs
assert "/skills/b" in dirs
assert "/skills/c" in dirs
def test_allowlisted_dirs_empty_catalog(self):
assert SkillCatalog().allowlisted_dirs == []
class TestSkillDirsPropagation:
def _make_ctx(self, **kwargs):
from unittest.mock import MagicMock
from framework.graph.node import NodeContext
return NodeContext(
runtime=MagicMock(),
node_id="n",
node_spec=MagicMock(),
memory={},
**kwargs,
)
def test_node_context_skill_dirs_default(self):
"""NodeContext.skill_dirs defaults to empty list."""
ctx = self._make_ctx()
assert ctx.skill_dirs == []
def test_node_context_skill_dirs_set(self):
"""NodeContext.skill_dirs can be populated."""
dirs = ["/skills/a", "/skills/b"]
ctx = self._make_ctx(skill_dirs=dirs)
assert ctx.skill_dirs == dirs
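The catalog behavior under test: framework-scope skills are filtered out of the prompt XML entirely (so a framework-only catalog renders to an empty string), while `allowlisted_dirs` exposes every base_dir regardless of scope. A sketch consistent with those assertions plus the escaping and name-sort tests:

```python
# Sketch only — element layout inferred from the substring assertions above.
from xml.sax.saxutils import escape

class SkillCatalog:
    def __init__(self, skills=()):
        self._skills = list(skills)

    @property
    def allowlisted_dirs(self) -> list[str]:
        return [s.base_dir for s in self._skills]  # framework dirs included

    def to_prompt(self) -> str:
        entries = []
        for s in sorted(self._skills, key=lambda s: s.name):
            if s.source_scope == "framework":
                continue  # framework skills never appear in the catalog XML
            entries.append(
                "<skill>"
                f"<name>{escape(s.name)}</name>"
                f"<description>{escape(s.description)}</description>"
                f"<location>{escape(s.location)}</location>"
                f"<base_dir>{escape(s.base_dir)}</base_dir>"
                "</skill>"
            )
        return "\n".join(entries)  # "" when no community skills remain
```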
+471
@@ -0,0 +1,471 @@
"""Tests for skill trust gating (AS-13)."""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
from framework.skills.parser import ParsedSkill
from framework.skills.trust import (
ProjectTrustClassification,
ProjectTrustDetector,
TrustedRepoStore,
TrustGate,
_is_localhost_remote,
_normalize_remote_url,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_skill(name: str = "test-skill", scope: str = "project") -> ParsedSkill:
return ParsedSkill(
name=name,
description="Test skill",
location=f"/fake/{name}/SKILL.md",
base_dir=f"/fake/{name}",
source_scope=scope,
body="Test skill instructions.",
)
# ---------------------------------------------------------------------------
# _normalize_remote_url
# ---------------------------------------------------------------------------
class TestNormalizeRemoteUrl:
def test_ssh_scp_format(self):
assert _normalize_remote_url("git@github.com:org/repo.git") == "github.com/org/repo"
def test_https_format(self):
assert _normalize_remote_url("https://github.com/org/repo.git") == "github.com/org/repo"
def test_https_no_dot_git(self):
assert _normalize_remote_url("https://github.com/org/repo") == "github.com/org/repo"
def test_ssh_url_format(self):
assert _normalize_remote_url("ssh://git@github.com/org/repo.git") == "github.com/org/repo"
def test_lowercased(self):
assert _normalize_remote_url("git@GitHub.COM:Org/Repo.git") == "github.com/org/repo"
def test_trailing_slash_stripped(self):
assert _normalize_remote_url("https://github.com/org/repo/") == "github.com/org/repo"
def test_gitlab(self):
assert _normalize_remote_url("git@gitlab.com:team/project.git") == "gitlab.com/team/project"
# ---------------------------------------------------------------------------
# _is_localhost_remote
# ---------------------------------------------------------------------------
class TestIsLocalhostRemote:
def test_localhost_https(self):
assert _is_localhost_remote("http://localhost/org/repo")
def test_127_0_0_1(self):
assert _is_localhost_remote("https://127.0.0.1/repo")
def test_github_not_local(self):
assert not _is_localhost_remote("https://github.com/org/repo")
def test_scp_localhost(self):
assert _is_localhost_remote("git@localhost:org/repo")
# ---------------------------------------------------------------------------
# TrustedRepoStore
# ---------------------------------------------------------------------------
class TestTrustedRepoStore:
def test_empty_store_is_not_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
assert not store.is_trusted("github.com/org/repo")
def test_trust_and_lookup(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
store.trust("github.com/org/repo", project_path="/some/path")
assert store.is_trusted("github.com/org/repo")
def test_revoke(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
store.trust("github.com/org/repo")
assert store.revoke("github.com/org/repo")
assert not store.is_trusted("github.com/org/repo")
def test_revoke_nonexistent_returns_false(self, tmp_path):
store = TrustedRepoStore(tmp_path / "trusted.json")
assert not store.revoke("github.com/nobody/nowhere")
def test_persists_across_instances(self, tmp_path):
path = tmp_path / "trusted.json"
store1 = TrustedRepoStore(path)
store1.trust("github.com/org/repo")
store2 = TrustedRepoStore(path)
assert store2.is_trusted("github.com/org/repo")
def test_atomic_write(self, tmp_path):
"""Save must not leave a .tmp file behind."""
path = tmp_path / "trusted.json"
store = TrustedRepoStore(path)
store.trust("github.com/org/repo")
assert not (tmp_path / "trusted.tmp").exists()
assert path.exists()
def test_corrupted_json_recovers_gracefully(self, tmp_path):
path = tmp_path / "trusted.json"
path.write_text("{not valid json{{", encoding="utf-8")
store = TrustedRepoStore(path)
assert not store.is_trusted("github.com/any/repo") # no crash
def test_json_schema(self, tmp_path):
path = tmp_path / "trusted.json"
store = TrustedRepoStore(path)
store.trust("github.com/org/repo", project_path="/work/repo")
data = json.loads(path.read_text())
assert data["version"] == 1
assert data["entries"][0]["repo_key"] == "github.com/org/repo"
assert "added_at" in data["entries"][0]
def test_list_entries(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/a/b")
store.trust("github.com/c/d")
entries = store.list_entries()
assert len(entries) == 2
# ---------------------------------------------------------------------------
# ProjectTrustDetector
# ---------------------------------------------------------------------------
class TestProjectTrustDetector:
def test_none_project_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(None)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_nonexistent_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(tmp_path / "nonexistent")
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_no_git_dir_always_trusted(self, tmp_path):
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_no_remote_always_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
# git command returns non-zero (no remote)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=1, stdout="")
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_localhost_remote_always_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="http://localhost/org/repo.git\n"
)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_trusted_by_store(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/trusted/repo")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="git@github.com:trusted/repo.git\n"
)
cls, key = det.classify(tmp_path)
assert cls == ProjectTrustClassification.TRUSTED_BY_USER
assert key == "github.com/trusted/repo"
def test_unknown_remote_untrusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
cls, key = det.classify(tmp_path)
assert cls == ProjectTrustClassification.UNTRUSTED
assert key == "github.com/stranger/repo"
def test_own_remotes_env_var(self, tmp_path, monkeypatch):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
monkeypatch.setenv("HIVE_OWN_REMOTES", "github.com/myorg/*")
det = ProjectTrustDetector(store)
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(
returncode=0, stdout="git@github.com:myorg/myrepo.git\n"
)
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_git_timeout_treated_as_trusted(self, tmp_path):
import subprocess
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run", side_effect=subprocess.TimeoutExpired("git", 3)):
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
def test_git_not_found_treated_as_trusted(self, tmp_path):
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
det = ProjectTrustDetector(store)
with patch("subprocess.run", side_effect=FileNotFoundError("git not found")):
cls, _ = det.classify(tmp_path)
assert cls == ProjectTrustClassification.ALWAYS_TRUSTED
# ---------------------------------------------------------------------------
# TrustGate
# ---------------------------------------------------------------------------
class TestTrustGate:
def test_framework_scope_always_passes(self, tmp_path):
skill = make_skill("fw-skill", "framework")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([skill], project_dir=None)
assert any(s.name == "fw-skill" for s in result)
def test_user_scope_always_passes(self, tmp_path):
skill = make_skill("user-skill", "user")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([skill], project_dir=None)
assert any(s.name == "user-skill" for s in result)
def test_no_project_skills_returns_early(self, tmp_path):
"""When there are no project-scope skills, trust detection is skipped."""
fw = make_skill("fw", "framework")
gate = TrustGate(store=TrustedRepoStore(tmp_path / "t.json"), interactive=False)
result = gate.filter_and_gate([fw], project_dir=tmp_path)
assert result == [fw]
def test_trusted_project_skills_pass(self, tmp_path):
"""Project skills from a trusted repo pass through."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
store.trust("github.com/trusted/repo")
skill = make_skill("proj-skill", "project")
gate = TrustGate(store=store, interactive=False)
with patch("subprocess.run") as m:
m.return_value = MagicMock(returncode=0, stdout="git@github.com:trusted/repo.git\n")
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "proj-skill" for s in result)
def test_untrusted_headless_skips_and_logs(self, tmp_path, caplog):
"""In non-interactive mode, untrusted project skills are skipped."""
import logging
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("evil-skill", "project")
gate = TrustGate(store=store, interactive=False)
with patch("subprocess.run") as m:
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/evil.git\n"
)
with caplog.at_level(logging.WARNING):
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "evil-skill" for s in result)
assert "untrusted" in caplog.text.lower() or "skipping" in caplog.text.lower()
def test_interactive_consent_session_only(self, tmp_path):
"""Option 1 (session only) includes skills without writing to store."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("session-skill", "project")
outputs = []
gate = TrustGate(
store=store,
interactive=True,
print_fn=outputs.append,
input_fn=lambda _: "1", # trust this session
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "session-skill" for s in result)
# Must NOT persist to trusted store
assert not store.is_trusted("github.com/stranger/repo")
def test_interactive_consent_permanent(self, tmp_path):
"""Option 2 (permanent) includes skills and persists to trusted store."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("perm-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "2", # trust permanently
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "perm-skill" for s in result)
assert store.is_trusted("github.com/stranger/repo")
def test_interactive_consent_deny(self, tmp_path):
"""Option 3 (deny) excludes project skills."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("bad-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "3", # deny
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "bad-skill" for s in result)
def test_env_var_override_trusts_all(self, tmp_path, monkeypatch):
"""HIVE_TRUST_PROJECT_SKILLS=1 bypasses gating entirely."""
monkeypatch.setenv("HIVE_TRUST_PROJECT_SKILLS", "1")
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("env-skill", "project")
gate = TrustGate(store=store, interactive=False)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert any(s.name == "env-skill" for s in result)
def test_keyboard_interrupt_treated_as_deny(self, tmp_path):
"""Ctrl-C during consent prompt should deny cleanly."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("interrupted-skill", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: (_ for _ in ()).throw(KeyboardInterrupt()),  # raises inside the lambda
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([skill], project_dir=tmp_path)
assert not any(s.name == "interrupted-skill" for s in result)
def test_security_notice_shown_once(self, tmp_path, monkeypatch):
"""Security notice (NFR-5) should be shown the first time only."""
# Use a temp sentinel path
sentinel = tmp_path / ".skill_trust_notice_shown"
monkeypatch.setattr("framework.skills.trust._NOTICE_SENTINEL_PATH", sentinel)
assert not sentinel.exists()
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
skill = make_skill("notice-skill", "project")
output_lines: list[str] = []
gate = TrustGate(
store=store,
interactive=True,
print_fn=output_lines.append,
input_fn=lambda _: "3",
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
gate.filter_and_gate([skill], project_dir=tmp_path)
assert sentinel.exists()
assert any("Security notice" in line for line in output_lines)
# Second run should NOT show the notice again
output_lines.clear()
skill2 = make_skill("notice-skill-2", "project")
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
gate.filter_and_gate([skill2], project_dir=tmp_path)
assert not any("Security notice" in line for line in output_lines)
def test_mixed_scopes_only_project_gated(self, tmp_path, monkeypatch):
"""Framework and user skills should pass through even if project skills are denied."""
(tmp_path / ".git").mkdir()
store = TrustedRepoStore(tmp_path / "t.json")
fw_skill = make_skill("fw", "framework")
user_skill = make_skill("usr", "user")
proj_skill = make_skill("proj", "project")
gate = TrustGate(
store=store,
interactive=True,
print_fn=lambda _: None,
input_fn=lambda _: "3", # deny project skills
)
with (
patch("sys.stdin.isatty", return_value=True),
patch("sys.stdout.isatty", return_value=True),
patch("subprocess.run") as m,
):
m.return_value = MagicMock(
returncode=0, stdout="https://github.com/stranger/repo.git\n"
)
result = gate.filter_and_gate([fw_skill, user_skill, proj_skill], project_dir=tmp_path)
names = {s.name for s in result}
assert "fw" in names
assert "usr" in names
assert "proj" not in names
+115
@@ -8,6 +8,7 @@ could cause a json.JSONDecodeError and crash execution.
import textwrap
from pathlib import Path
from types import SimpleNamespace
from framework.runner.tool_registry import ToolRegistry
@@ -91,3 +92,117 @@ def test_discover_from_module_handles_empty_content(tmp_path):
result = registered.executor({})
assert isinstance(result, dict)
assert result == {}
class _RegistryFakeClient:
def __init__(self, config):
self.config = config
self.connect_calls = 0
self.disconnect_calls = 0
def connect(self) -> None:
self.connect_calls += 1
def disconnect(self) -> None:
self.disconnect_calls += 1
def list_tools(self):
return [
SimpleNamespace(
name="pooled_tool",
description="Tool from MCP",
input_schema={"type": "object", "properties": {}, "required": []},
)
]
def call_tool(self, tool_name, arguments):
return [{"text": f"{tool_name}:{arguments}"}]
def test_register_mcp_server_uses_connection_manager_when_enabled(monkeypatch):
registry = ToolRegistry()
client = _RegistryFakeClient(SimpleNamespace(name="shared"))
manager_calls: list[tuple[str, str]] = []
class FakeManager:
def acquire(self, config):
manager_calls.append(("acquire", config.name))
client.config = config
return client
def release(self, server_name: str) -> None:
manager_calls.append(("release", server_name))
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "shared", "transport": "stdio", "command": "echo"},
use_connection_manager=True,
)
assert count == 1
assert manager_calls == [("acquire", "shared")]
registry.cleanup()
assert manager_calls == [("acquire", "shared"), ("release", "shared")]
assert client.disconnect_calls == 0
def test_register_mcp_server_defaults_to_connection_manager(monkeypatch):
"""Default behavior uses the connection manager (reuse enabled by default)."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
class FakeManager:
def acquire(self, config):
return fake_client_factory(config)
def release(self, server_name):
pass
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
)
assert count == 1
assert len(created_clients) == 1
def test_register_mcp_server_direct_client_when_manager_disabled(monkeypatch):
"""When use_connection_manager=False, a direct MCPClient is created."""
registry = ToolRegistry()
created_clients: list[_RegistryFakeClient] = []
def fake_client_factory(config):
client = _RegistryFakeClient(config)
created_clients.append(client)
return client
monkeypatch.setattr("framework.runner.mcp_client.MCPClient", fake_client_factory)
count = registry.register_mcp_server(
{"name": "direct", "transport": "stdio", "command": "echo"},
use_connection_manager=False,
)
assert count == 1
assert len(created_clients) == 1
assert created_clients[0].connect_calls == 1
registry.cleanup()
assert created_clients[0].disconnect_calls == 1
+4 -4
View File
@@ -157,7 +157,7 @@ All bounty types open in parallel. Contributors self-select. Daily progress upda
PR merged with bounty:* label
→ GitHub Action runs bounty-tracker.ts
→ Calculates points from label
→ Resolves GitHub → Discord ID via contributors.yml
→ Resolves GitHub → Discord ID via MongoDB (hive.contributors)
→ Pushes XP to Lurkr API
→ Posts notification to #integrations-announcements
```
@@ -166,7 +166,7 @@ See the [Setup Guide](setup-guide.md) for full configuration (Lurkr, webhooks, s
### Identity Linking
Contributors link GitHub ↔ Discord by opening a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue. A GitHub Action auto-adds them to `contributors.yml` and closes the issue.
Contributors link GitHub ↔ Discord by running `/link-github` in Discord. The bot verifies ownership via a public gist, then stores the mapping in MongoDB.
Without this link, bounties are still tracked but Lurkr can't push XP to your Discord account.
@@ -181,7 +181,7 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
| Agent Builder role | Lurkr bot | Auto-assigned at level 5 |
| OSS Contributor role | Lurkr bot | Auto-assigned at level 15 |
| Core Contributor role | Maintainer | Manual (involves money) |
| Identity linking | contributors.yml | PR-based, reviewed by maintainers |
| Identity linking | Discord bot → MongoDB | `/link-github` command with gist verification |
## Guides
@@ -203,4 +203,4 @@ Without this link, bounties are still tracked but Lurkr can't push XP to your Di
- `.github/workflows/weekly-leaderboard.yml` — Monday leaderboard post
- `scripts/bounty-tracker.ts` — Point calculation, Lurkr API, Discord formatting
- `scripts/setup-bounty-labels.sh` — One-time label setup
- `contributors.yml` — GitHub ↔ Discord identity mapping
- MongoDB `hive.contributors` — GitHub ↔ Discord identity mapping (managed by Discord bot)
+2 -4
View File
@@ -6,9 +6,7 @@ Earn XP, Discord roles, and eventually real money by contributing to the Aden ag
### 1. Link your GitHub and Discord
Open a [Link Discord Account](https://github.com/aden-hive/hive/issues/new?template=link-discord.yml) issue — just paste your Discord ID and submit. A GitHub Action will automatically add you to `contributors.yml` and close the issue.
To find your Discord ID: Discord Settings > Advanced > Enable **Developer Mode**, then right-click your name > **Copy User ID**.
Run `/link-github your-github-username` in Discord. The bot will give you a verification code — create a public gist with that code, then run `/verify`. Done.
Without this link, Lurkr can't push XP to your Discord account.
@@ -154,7 +152,7 @@ A: Yes. Most services have free tiers. The bounty issue links to where you get t
A: Contribute consistently across different bounty types for 4+ weeks. Maintainers will nominate you.
**Q: What if I haven't linked my Discord yet?**
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Add yourself to `contributors.yml`.
A: You'll still get credit in GitHub, but no Lurkr XP or Discord roles. Run `/link-github` in Discord.
## Quick Reference
+4 -2
View File
@@ -104,6 +104,8 @@ Repo Settings > Secrets and variables > Actions:
| `DISCORD_BOUNTY_WEBHOOK_URL` | Webhook URL from Step 5 |
| `LURKR_API_KEY` | Lurkr API key from Step 4f |
| `LURKR_GUILD_ID` | Your Discord server ID\* |
| `BOT_API_URL` | Discord bot API URL |
| `BOT_API_KEY` | Discord bot API key |
\*Enable Developer Mode in Discord, right-click server name > Copy Server ID.
@@ -146,12 +148,12 @@ powerbi, redis
- [ ] All 5 GitHub secrets added
- [ ] Both workflows enabled (`bounty-completed.yml`, `weekly-leaderboard.yml`)
- [ ] Test PR + merge triggers Discord notification
- [ ] `contributors.yml` exists at repo root
- [ ] MongoDB `hive.contributors` collection accessible
## Troubleshooting
**No Discord message:** Check `DISCORD_BOUNTY_WEBHOOK_URL` secret and Action logs.
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor is in `contributors.yml`, check Action logs for `Lurkr XP push failed`.
**Lurkr XP not awarded:** Confirm API key is Read/Write, contributor has run `/link-github` in Discord, check Action logs for `Lurkr XP push failed`.
**Role not assigned:** Verify role rewards in the Lurkr dashboard or via `/config set`. Lurkr's role must be above the roles it assigns in server hierarchy.
+290
View File
@@ -0,0 +1,290 @@
# Agent Skills User Guide
This guide covers how to use, create, and manage Agent Skills in the Hive framework. Agent Skills follow the open [Agent Skills standard](https://agentskills.io) — skills written for Claude Code, Cursor, or other compatible agents work in Hive unchanged.
## What are skills?
Skills are folders containing a `SKILL.md` file that teaches an agent how to perform a specific task. They can also bundle scripts, templates, and reference materials. Skills are loaded on demand — the agent sees a lightweight catalog at startup and pulls in full instructions only when relevant.
## Quick start
### Install a skill
Drop a skill folder into one of the discovery directories:
```bash
# Project-level (shared with the repo)
mkdir -p .hive/skills/my-skill
cat > .hive/skills/my-skill/SKILL.md << 'EOF'
---
name: my-skill
description: Does X when the user asks about Y.
---
# My Skill
Step-by-step instructions for the agent...
EOF
```
The agent will discover it automatically on the next session.
### List discovered skills
```bash
hive skill list
```
Output groups skills by scope:
```
PROJECT SKILLS
────────────────────────────────────
• my-skill
Does X when the user asks about Y.
/home/user/project/.hive/skills/my-skill/SKILL.md
USER SKILLS
────────────────────────────────────
• deep-research
Multi-step web research with source verification.
/home/user/.hive/skills/deep-research/SKILL.md
```
## Where to put skills
Hive scans five directories at startup, in this precedence order:
| Scope | Path | Use case |
|-------|------|----------|
| Project (Hive) | `<project>/.hive/skills/` | Skills specific to this repo |
| Project (cross-client) | `<project>/.agents/skills/` | Skills shared across Claude Code, Cursor, etc. |
| User (Hive) | `~/.hive/skills/` | Personal skills available in all projects |
| User (cross-client) | `~/.agents/skills/` | Personal cross-client skills |
| Framework | *(built-in)* | Default operational skills shipped with Hive |
**Precedence**: If two skills share the same name, the higher-precedence location wins. A project-level `code-review` skill overrides a user-level one with the same name.
**Cross-client paths**: The `.agents/skills/` directories are a convention shared across compatible agents. A skill installed at `~/.agents/skills/pdf-processing/` is visible to Hive, Claude Code, Cursor, and other compatible tools simultaneously.
## Creating a skill
### Directory structure
```
my-skill/
├── SKILL.md # Required — metadata + instructions
├── scripts/ # Optional — executable code
│ └── run.py
├── references/ # Optional — supplementary docs
│ └── api-reference.md
└── assets/ # Optional — templates, data files
└── template.json
```
### SKILL.md format
Every skill needs a `SKILL.md` with YAML frontmatter and a markdown body:
```markdown
---
name: my-skill
description: Extract and summarize PDF documents. Use when the user mentions PDFs or document extraction.
---
# PDF Processing
## When to use
Use this skill when the user needs to extract text from PDFs or merge documents.
## Steps
1. Check if pdfplumber is available...
2. Extract text using...
## Edge cases
- Scanned PDFs need OCR first...
```
### Frontmatter fields
| Field | Required | Description |
|-------|----------|-------------|
| `name` | Yes | Lowercase letters, numbers, hyphens. Must match the parent directory name. Max 64 chars. |
| `description` | Yes | What the skill does and when to use it. Max 1024 chars. Include keywords that help the agent match tasks. |
| `license` | No | License name or reference to a bundled LICENSE file. |
| `compatibility` | No | Environment requirements (e.g., "Requires git, docker"). |
| `metadata` | No | Arbitrary key-value pairs (author, version, etc.). |
| `allowed-tools` | No | Space-delimited list of pre-approved tools. |
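For reference, a frontmatter block that uses the optional fields might look like this (values are illustrative, and the tool names under `allowed-tools` are hypothetical):

```yaml
---
name: pdf-processing
description: Extract text and tables from PDF files. Use when the user mentions PDFs or document extraction.
license: MIT
compatibility: Requires python3 and pdfplumber
metadata:
  author: jane@example.com
  version: 1.0.0
allowed-tools: read_file run_command  # hypothetical tool names
---
```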
### Writing good descriptions
The description is critical — it's what the agent uses to decide whether to activate a skill. Be specific:
```yaml
# Good — tells the agent what and when
description: Extract text and tables from PDF files, fill PDF forms, and merge multiple PDFs. Use when working with PDF documents or when the user mentions PDFs, forms, or document extraction.
# Bad — too vague for the agent to match
description: Helps with PDFs.
```
### Writing good instructions
The markdown body is loaded into the agent's context when the skill is activated. Tips:
- **Be procedural**: Step-by-step instructions work better than abstract descriptions.
- **Keep it focused**: Stay under 500 lines / 5000 tokens. Move detailed reference material to `references/`.
- **Use relative paths**: Reference bundled files with relative paths (`scripts/run.py`, `references/guide.md`).
- **Include examples**: Show sample inputs and expected outputs.
- **Cover edge cases**: Tell the agent what to do when things go wrong.
## How skills are activated
Skills use **progressive disclosure** — three tiers that keep context usage efficient:
### Tier 1: Catalog (always loaded)
At session start, the agent sees a compact catalog of all available skills (name + description only, ~50-100 tokens each). This is how it knows what skills exist.
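The exact rendering of the catalog is not specified here, but conceptually each entry is just the two required frontmatter fields, along these lines:

```
pdf-processing: Extract text and tables from PDF files, fill PDF forms, and merge multiple PDFs.
deep-research: Multi-step web research with source verification.
```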
### Tier 2: Instructions (on demand)
When the agent determines a skill is relevant to the current task, it reads the full `SKILL.md` body into context. This happens automatically — the agent matches the task against skill descriptions and activates the best fit.
### Tier 3: Resources (on demand)
When skill instructions reference supporting files (`scripts/extract.py`, `references/api-docs.md`), the agent reads those individually as needed.
### Pre-activated skills
Some agents are configured to load specific skills at session start (skipping the catalog phase). This is set in the agent's configuration:
```python
# In agent definition
skills = ["code-review", "deep-research"]
```
Pre-activated skills have their full instructions loaded from the start, without waiting for the agent to decide they're relevant.
## Trust and security
### Why trust gating exists
Project-level skills come from the repository being worked on. If you clone an untrusted repo that contains a `.hive/skills/` directory, those skills could inject instructions into the agent's system prompt. Trust gating prevents this.
**User-level and framework skills are always trusted.** Only project-scope skills go through trust gating.
### What happens with untrusted project skills
When Hive encounters project-level skills from a repo you haven't trusted before, it shows a consent prompt:
```
============================================================
SKILL TRUST REQUIRED
============================================================
The project at /home/user/new-project wants to load 2 skill(s)
that will inject instructions into the agent's system prompt.
Source: github.com/org/new-project
Skills requesting access:
• deploy-pipeline
"Automated deployment workflow for this project."
/home/user/new-project/.hive/skills/deploy-pipeline/SKILL.md
• code-standards
"Project-specific coding standards and review checklist."
/home/user/new-project/.hive/skills/code-standards/SKILL.md
Options:
1) Trust this session only
2) Trust permanently — remember for future runs
3) Deny — skip all project-scope skills from this repo
────────────────────────────────────────────────────────────
Select option (1-3):
```
### Trust a repo via CLI
To trust a repo permanently without the interactive prompt:
```bash
hive skill trust /path/to/project
```
This stores the trust decision in `~/.hive/trusted_repos.json`, keyed by the normalized git remote URL (e.g., `github.com/org/repo`).
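The on-disk schema is an implementation detail, but a minimal sketch of the idea is a map from normalized remote to a trust record:

```json
{
  "github.com/org/repo": {
    "trusted": true
  }
}
```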
### Automatic trust
Some repos are trusted automatically:
- **No git repo**: Directories without `.git/` are always trusted.
- **No remote**: Local-only git repos (no `origin` remote) are always trusted.
- **Localhost remotes**: Repos with `localhost`/`127.0.0.1` remotes are always trusted.
- **Own-remote patterns**: Repos matching patterns in `~/.hive/own_remotes` or the `HIVE_OWN_REMOTES` env var are always trusted.
### Configure own-remote patterns
If you trust all repos from your organization:
```bash
# Via file (one pattern per line)
echo "github.com/my-org/*" >> ~/.hive/own_remotes
echo "gitlab.com/my-team/*" >> ~/.hive/own_remotes
# Via environment variable (comma-separated)
export HIVE_OWN_REMOTES="github.com/my-org/*,github.com/my-corp/*"
```
### CI / headless environments
In non-interactive environments, untrusted project skills are silently skipped. To trust them explicitly:
```bash
export HIVE_TRUST_PROJECT_SKILLS=1
hive run my-agent
```
## Default skills
Hive ships with six built-in operational skills that provide runtime resilience. These are always loaded (unless disabled) and appear as "Operational Protocols" in the agent's system prompt.
| Skill | Purpose |
|-------|---------|
| `hive.note-taking` | Structured working notes in shared memory |
| `hive.batch-ledger` | Track per-item status in batch operations |
| `hive.context-preservation` | Save context before context window pruning |
| `hive.quality-monitor` | Self-assess output quality periodically |
| `hive.error-recovery` | Structured error classification and recovery |
| `hive.task-decomposition` | Break complex tasks into subtasks |
### Disable default skills
In your agent configuration:
```python
# Disable a specific default skill
default_skills = {
"hive.quality-monitor": {"enabled": False},
}
# Disable all default skills
default_skills = {
"_all": {"enabled": False},
}
```
## Environment variables
| Variable | Description |
|----------|-------------|
| `HIVE_TRUST_PROJECT_SKILLS=1` | Bypass trust gating for all project-level skills (CI override) |
| `HIVE_OWN_REMOTES` | Comma-separated glob patterns for auto-trusted remotes (e.g., `github.com/myorg/*`) |
## Compatibility with other agents
Skills written for any Agent Skills-compatible agent work in Hive:
- Place them in `.agents/skills/` (cross-client) or `.hive/skills/` (Hive-specific).
- The `SKILL.md` format is identical across Claude Code, Cursor, Gemini CLI, and others.
- Skills installed at `~/.agents/skills/` are visible to all compatible agents on your machine.
See the [Agent Skills specification](https://agentskills.io/specification) for the full format reference.
+136
View File
@@ -0,0 +1,136 @@
# SDR Agent
An AI-powered sales development outreach automation template for [Hive](https://github.com/aden-hive/hive).
Score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts — all with human review before anything is sent.
## Overview
The SDR Agent automates the full outreach pipeline:
```
Intake → Score Contacts → Filter Contacts → Personalize → Send Outreach → Report
```
1. **Intake** — Accept a contact list and outreach goal; confirm strategy with user
2. **Score Contacts** — Rank contacts 0-100 using priority factors (alumni, degree, domain, etc.)
3. **Filter Contacts** — Detect and skip suspicious/fake profiles (risk score ≥ 7)
4. **Personalize** — Generate an 80-120 word personalized message per contact
5. **Send Outreach** — Create Gmail drafts for human review (never sends automatically)
6. **Report** — Summarize campaign: contacts scored, filtered, drafted
## Quickstart
```bash
cd examples/templates/sdr_agent
# Run interactively via TUI
python -m sdr_agent tui
# Run via CLI with a contacts JSON string
python -m sdr_agent run \
--contacts '[{"name":"Jane Doe","company":"Acme","title":"Engineer","connection_degree":"2nd","is_alumni":true}]' \
--goal "coffee chat" \
--background "Learning Technologist at UWO" \
--max-contacts 20
# Validate agent structure
python -m sdr_agent validate
```
## Contact Schema
Each contact in your list supports the following fields:
| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `name` | string | ✅ | Contact's full name |
| `email` | string | ❌ | Email address (draft placeholder if missing) |
| `company` | string | ✅ | Current company |
| `title` | string | ✅ | Job title |
| `linkedin_url` | string | ❌ | LinkedIn profile URL |
| `connection_degree` | string | ❌ | `"1st"`, `"2nd"`, or `"3rd"` |
| `is_alumni` | boolean | ❌ | Shares school with user |
| `school_name` | string | ❌ | School name for alumni messaging |
| `connections_count` | integer | ❌ | Number of LinkedIn connections |
| `mutual_connections` | integer | ❌ | Count of mutual connections |
| `has_photo` | boolean | ❌ | Has a profile photo |
## Scoring Model
The `score-contacts` node ranks each contact 0-100:
| Factor | Points |
|--------|--------|
| Alumni | +30 |
| 1st degree | +25 |
| 2nd degree | +20 |
| 3rd degree | +10 |
| Domain verified | +10 |
| Mutual connections (×1, max 10) | +10 |
| Active job posting | +10 |
| Has profile photo | +5 |
| 500+ connections | +5 |
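The scoring itself is performed by the LLM node following its prompt, but the arithmetic is simple enough to sketch in Python. Field names follow the contact schema above (`company_domain_verified` and `active_job_posting` appear in the bundled sample contacts); the helper itself is illustrative, not part of the template:

```python
def priority_score(contact: dict) -> int:
    """Additive priority score from the table above, capped at 100."""
    degree_points = {"1st": 25, "2nd": 20, "3rd": 10}
    score = 30 if contact.get("is_alumni") else 0
    score += degree_points.get(contact.get("connection_degree"), 0)
    score += 10 if contact.get("company_domain_verified") else 0
    score += min(contact.get("mutual_connections", 0), 10)  # +1 each, max 10
    score += 10 if contact.get("active_job_posting") else 0
    score += 5 if contact.get("has_photo") else 0
    score += 5 if contact.get("connections_count", 0) >= 500 else 0
    return min(score, 100)
```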
## Scam Detection
The `filter-contacts` node calculates a risk score and excludes contacts with risk ≥ 7:
| Red Flag | Risk |
|----------|------|
| Fewer than 50 connections | +3 |
| No profile photo | +2 |
| Fewer than 2 work positions | +2 |
| Generic title + few connections | +2 |
| Unverifiable company | +2 |
| AI-generated-looking profile | +2 |
| 5000+ connections, 0 mutual | +1 |
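Most of these flags are mechanical; the judgment calls (unverifiable company, AI-looking profile text) are left to the model. A sketch of the mechanical subset, again using the sample-contact field names (`positions_count` is hypothetical, since work history is not part of the schema above):

```python
GENERIC_TITLES = {"entrepreneur", "ceo", "consultant"}  # examples from the node prompt

def risk_score(contact: dict) -> int:
    """Mechanical subset of the risk rules; subjective flags are omitted."""
    connections = contact.get("connections_count", 0)
    risk = 3 if connections < 50 else 0
    risk += 2 if not contact.get("has_photo") else 0
    risk += 2 if contact.get("positions_count", 2) < 2 else 0  # hypothetical field
    if contact.get("title", "").lower() in GENERIC_TITLES and connections < 100:
        risk += 2
    if connections >= 5000 and contact.get("mutual_connections", 0) == 0:
        risk += 1
    return risk  # skip when risk >= 7; flag but keep when 4-6
```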
## Pipeline Output Files
Each run writes to `~/.hive/agents/sdr_agent/data/`:
| File | Contents |
|------|----------|
| `contacts.jsonl` | Raw contact list |
| `scored_contacts.jsonl` | Contacts with `priority_score` |
| `safe_contacts.jsonl` | Contacts passing scam filter |
| `personalized_contacts.jsonl` | Contacts with `outreach_message` |
| `drafts.jsonl` | Draft creation records |
## Safety Constraints
- **Never sends emails** — only `gmail_create_draft` is called; human must review and send
- **Batch limit** — processes at most `max_contacts` per run (default: 20)
- **Skip suspicious** — contacts with `risk_score ≥ 7` are always excluded
## Tools Required
- `gmail_create_draft` — create Gmail draft for each contact
- `load_data` — read JSONL data files
- `append_data` — write to JSONL data files
## Architecture
```
┌──────────────────────────────────────────────────────────────┐
│ SDR Agent │
│ │
│ ┌────────┐ ┌───────────────┐ ┌────────────────┐ │
│ │ Intake │──▶│ Score Contacts│──▶│ Filter Contacts│ │
│ └────────┘ └───────────────┘ └────────────────┘ │
│ ▲ │ │
│ │ ▼ │
│ ┌────────┐ ┌───────────────┐ ┌─────────────┐ │
│ │ Report │◀──│ Send Outreach │◀──│ Personalize │ │
│ └────────┘ └───────────────┘ └─────────────┘ │
│ │
│ ● client_facing nodes: intake, report │
│ ● automated nodes: score-contacts, filter-contacts, │
│ personalize, send-outreach │
└──────────────────────────────────────────────────────────────┘
```
## Inspiration
This template is inspired by real-world SDR automation patterns, including contact ranking, scam detection, and two-step personalization (hook extraction → message generation) — demonstrating how job-search and sales outreach workflows can be modeled as AI agent pipelines in Hive.
+45
View File
@@ -0,0 +1,45 @@
"""
SDR Agent — Automated sales development outreach pipeline.
Score contacts by priority, filter suspicious profiles, generate personalized
outreach messages, and create Gmail drafts for human review before sending.
"""
from .agent import (
SDRAgent,
default_agent,
goal,
nodes,
edges,
loop_config,
async_entry_points,
entry_node,
entry_points,
pause_nodes,
terminal_nodes,
conversation_mode,
identity_prompt,
)
from .config import RuntimeConfig, AgentMetadata, default_config, metadata
__version__ = "1.0.0"
__all__ = [
"SDRAgent",
"default_agent",
"goal",
"nodes",
"edges",
"loop_config",
"async_entry_points",
"entry_node",
"entry_points",
"pause_nodes",
"terminal_nodes",
"conversation_mode",
"identity_prompt",
"RuntimeConfig",
"AgentMetadata",
"default_config",
"metadata",
]
+234
View File
@@ -0,0 +1,234 @@
"""
CLI entry point for SDR Agent.
Automates sales development outreach: score contacts, filter suspicious
profiles, generate personalized messages, and create Gmail drafts.
"""
import asyncio
import json
import logging
import sys
import click
from .agent import default_agent, SDRAgent
def setup_logging(verbose=False, debug=False):
"""Configure logging for execution visibility."""
if debug:
level, fmt = logging.DEBUG, "%(asctime)s %(name)s: %(message)s"
elif verbose:
level, fmt = logging.INFO, "%(message)s"
else:
level, fmt = logging.WARNING, "%(levelname)s: %(message)s"
logging.basicConfig(level=level, format=fmt, stream=sys.stderr)
logging.getLogger("framework").setLevel(level)
@click.group()
@click.version_option(version="1.0.0")
def cli():
"""SDR Agent - Automated outreach with contact scoring and personalization."""
pass
@cli.command()
@click.option(
"--contacts",
"-c",
type=str,
required=True,
help="JSON string or file path of contacts list",
)
@click.option(
"--goal",
"-g",
type=str,
default="coffee chat",
help="Outreach goal (e.g. 'coffee chat', 'sales pitch')",
)
@click.option(
"--background",
"-b",
type=str,
default="",
help="Your background/role for personalization",
)
@click.option(
"--max-contacts",
"-m",
type=int,
default=20,
help="Max contacts to process per batch (default: 20)",
)
@click.option(
"--mock", is_flag=True, help="Run in mock mode without LLM or Gmail calls"
)
@click.option("--quiet", "-q", is_flag=True, help="Only output result JSON")
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
@click.option("--debug", is_flag=True, help="Show debug logging")
def run(contacts, goal, background, max_contacts, mock, quiet, verbose, debug):
"""Execute an SDR outreach campaign for the given contacts."""
if not quiet:
setup_logging(verbose=verbose, debug=debug)
context = {
"contacts": contacts,
"outreach_goal": goal,
"user_background": background,
"max_contacts": str(max_contacts),
}
result = asyncio.run(default_agent.run(context, mock_mode=mock))
output_data = {
"success": result.success,
"steps_executed": result.steps_executed,
"output": result.output,
}
if result.error:
output_data["error"] = result.error
click.echo(json.dumps(output_data, indent=2, default=str))
sys.exit(0 if result.success else 1)
@cli.command()
@click.option("--mock", is_flag=True, help="Run in mock mode")
@click.option("--verbose", "-v", is_flag=True, help="Show execution details")
@click.option("--debug", is_flag=True, help="Show debug logging")
def tui(mock, verbose, debug):
"""Launch the TUI dashboard for interactive SDR outreach."""
setup_logging(verbose=verbose, debug=debug)
try:
from framework.tui.app import AdenTUI
except ImportError:
click.echo(
"TUI requires the 'textual' package. Install with: pip install textual"
)
sys.exit(1)
async def run_with_tui():
agent = SDRAgent()
await agent.start(mock_mode=mock)
try:
app = AdenTUI(agent._agent_runtime)
await app.run_async()
finally:
await agent.stop()
asyncio.run(run_with_tui())
@cli.command()
@click.option("--json", "output_json", is_flag=True)
def info(output_json):
"""Show agent information."""
info_data = default_agent.info()
if output_json:
click.echo(json.dumps(info_data, indent=2))
else:
click.echo(f"Agent: {info_data['name']}")
click.echo(f"Version: {info_data['version']}")
click.echo(f"Description: {info_data['description']}")
click.echo(f"\nNodes: {', '.join(info_data['nodes'])}")
click.echo(f"Client-facing: {', '.join(info_data['client_facing_nodes'])}")
click.echo(f"Entry: {info_data['entry_node']}")
click.echo(f"Terminal: {', '.join(info_data['terminal_nodes'])}")
@cli.command()
def validate():
"""Validate agent structure."""
validation = default_agent.validate()
if validation["valid"]:
click.echo("Agent is valid")
if validation["warnings"]:
for warning in validation["warnings"]:
click.echo(f" WARNING: {warning}")
else:
click.echo("Agent has errors:")
for error in validation["errors"]:
click.echo(f" ERROR: {error}")
sys.exit(0 if validation["valid"] else 1)
@cli.command()
@click.option("--verbose", "-v", is_flag=True)
def shell(verbose):
"""Interactive SDR outreach session (CLI, no TUI)."""
asyncio.run(_interactive_shell(verbose))
async def _interactive_shell(verbose=False):
"""Async interactive shell."""
setup_logging(verbose=verbose)
click.echo("=== SDR Agent ===")
click.echo("Automated contact scoring, filtering, and outreach personalization\n")
agent = SDRAgent()
await agent.start()
try:
while True:
try:
goal = await asyncio.get_running_loop().run_in_executor(
None, input, "Outreach goal (e.g. 'coffee chat')> "
)
if goal.lower() in ["quit", "exit", "q"]:
click.echo("Goodbye!")
break
contacts = await asyncio.get_running_loop().run_in_executor(
None, input, "Contacts (JSON)> "
)
background = await asyncio.get_running_loop().run_in_executor(
None, input, "Your background/role> "
)
if not contacts.strip():
continue
click.echo("\nRunning SDR campaign...\n")
result = await agent.trigger_and_wait(
"start",
{
"contacts": contacts,
"outreach_goal": goal,
"user_background": background,
"max_contacts": "20",
},
)
if result is None:
click.echo("\n[Execution timed out]\n")
continue
if result.success:
output = result.output
if "summary_report" in output:
click.echo("\n--- Campaign Report ---\n")
click.echo(output["summary_report"])
click.echo("\n")
else:
click.echo(f"\nCampaign failed: {result.error}\n")
except KeyboardInterrupt:
click.echo("\nGoodbye!")
break
except Exception as e:
click.echo(f"Error: {e}", err=True)
import traceback
traceback.print_exc()
finally:
await agent.stop()
if __name__ == "__main__":
cli()
+378
View File
@@ -0,0 +1,378 @@
{
"agent": {
"id": "sdr_agent",
"name": "SDR Agent",
"version": "1.0.0",
"description": "Automate sales development outreach using AI-powered contact scoring, scam detection, and personalized message generation. Score contacts by priority, filter suspicious profiles, generate personalized outreach messages, and create Gmail drafts for review — all without sending emails automatically."
},
"graph": {
"id": "sdr-agent-graph",
"goal_id": "sdr-agent",
"version": "1.0.0",
"entry_node": "intake",
"entry_points": {
"start": "intake"
},
"pause_nodes": [],
"terminal_nodes": ["complete"],
"conversation_mode": "continuous",
"identity_prompt": "You are an SDR (Sales Development Representative) assistant. You help users automate their outreach by scoring contacts, filtering suspicious profiles, generating personalized messages, and creating Gmail drafts — all with human review before anything is sent.",
"nodes": [
{
"id": "intake",
"name": "Intake",
"description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
"node_type": "event_loop",
"input_keys": [
"contacts",
"outreach_goal",
"max_contacts",
"user_background"
],
"output_keys": [
"contacts",
"outreach_goal",
"max_contacts",
"user_background"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are an SDR (Sales Development Representative) assistant helping automate outreach.\n\n**STEP 1 — Respond to the user (text only, NO tool calls):**\n\nRead the user's input from context. Confirm your understanding of:\n- The contact list they provided (or ask them to provide one)\n- Their outreach goal (e.g. \"coffee chat\", \"sales pitch\", \"networking\")\n- Their background/role (used to personalize messages)\n- The batch size (max_contacts). Default to 20 if not specified.\n\nPresent a summary like:\n\"Here's what I'll do:\n1. Score and rank your contacts by priority (alumni status, connection degree, etc.)\n2. Filter out suspicious or low-quality profiles (risk score ≥ 7)\n3. Generate a personalized outreach message for each contact\n4. Create Gmail draft emails for your review — I never send automatically\n\nReady to proceed with [N] contacts for [goal]?\"\n\n**STEP 2 — After the user confirms, call set_output:**\n\n- set_output(\"contacts\", <the contact list as a JSON string>)\n- set_output(\"outreach_goal\", <the confirmed goal, e.g. \"coffee chat\">)\n- set_output(\"max_contacts\", <the confirmed batch size as a string, e.g. \"20\">)\n- set_output(\"user_background\", <user's background/role, e.g. \"Learning Technologist at UWO\">)",
"tools": [],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": true,
"success_criteria": null
},
{
"id": "score-contacts",
"name": "Score Contacts",
"description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
"node_type": "event_loop",
"input_keys": [
"contacts",
"outreach_goal"
],
"output_keys": [
"scored_contacts"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are a contact prioritization engine. Score each contact from 0 to 100.\n\n**SCORING RULES (additive):**\n- Alumni of the user's school: +30 points\n- 1st degree connection: +25 points\n- 2nd degree connection: +20 points\n- 3rd degree connection: +10 points\n- Domain verified (company email matches LinkedIn company): +10 points\n- Has mutual connections (1 point each, max 10): up to +10 points\n- Active job posting at their company: +10 points\n- Has a profile photo: +5 points\n- Over 500 connections: +5 points\n\nCap the final score at 100.\n\n**STEP 1 — Load the contacts:**\nCall load_data(filename=\"contacts.jsonl\") to read the contact list.\nIf \"contacts\" in context is a JSON string (not a filename), write it first:\n- For each contact in the list, call append_data(filename=\"contacts.jsonl\", data=<JSON contact object>)\nThen read it back.\n\n**STEP 2 — Score each contact:**\nFor each contact, calculate the priority score using the rules above.\nAdd a \"priority_score\" field to each contact object.\n\n**STEP 3 — Write scored contacts and set output:**\n- Call append_data(filename=\"scored_contacts.jsonl\", data=<JSON contact with priority_score>) for each contact.\n- Sort contacts by priority_score (highest first) in your final output.\n- Call set_output(\"scored_contacts\", \"scored_contacts.jsonl\")",
"tools": [
"load_data",
"append_data"
],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": false,
"success_criteria": null
},
{
"id": "filter-contacts",
"name": "Filter Contacts",
"description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
"node_type": "event_loop",
"input_keys": [
"scored_contacts"
],
"output_keys": [
"safe_contacts",
"filtered_count"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are a profile authenticity analyzer. Your job is to detect suspicious or fake LinkedIn profiles.\n\n**RISK SCORING RULES (additive):**\n- Fewer than 50 connections: +3 points\n- No profile photo: +2 points\n- Fewer than 2 positions in work history: +2 points\n- Generic title (e.g. \"entrepreneur\", \"CEO\", \"consultant\") AND fewer than 100 connections: +2 points\n- Company name appears generic or unverifiable: +2 points\n- Profile text seems auto-generated or overly promotional: +2 points\n- Connection count over 5000 with no mutual connections: +1 point\n\n**DECISION RULE:**\n- risk_score < 4: SAFE — include in outreach\n- risk_score 46: CAUTION — include but flag\n- risk_score ≥ 7: SKIP — exclude from outreach\n\n**STEP 1 — Load scored contacts:**\nCall load_data(filename=<the \"scored_contacts\" value from context>).\nProcess contacts chunk by chunk if has_more=true.\n\n**STEP 2 — Analyze each contact:**\nFor each contact, calculate a risk_score using the rules above.\nDetermine: is_safe (risk_score < 7), recommendation (safe/caution/skip), flags (list of triggered rules).\n\n**STEP 3 — Write safe contacts and set output:**\n- For each contact where risk_score < 7: call append_data(filename=\"safe_contacts.jsonl\", data=<contact JSON with risk_score and flags added>)\n- Track how many contacts were filtered (risk_score ≥ 7)\n- Call set_output(\"safe_contacts\", \"safe_contacts.jsonl\")\n- Call set_output(\"filtered_count\", <number of skipped contacts as string>)",
"tools": [
"load_data",
"append_data"
],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": false,
"success_criteria": null
},
{
"id": "personalize",
"name": "Personalize",
"description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
"node_type": "event_loop",
"input_keys": [
"safe_contacts",
"outreach_goal",
"user_background"
],
"output_keys": [
"personalized_contacts"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are a professional outreach message writer. Generate personalized messages for each contact.\n\n**TWO-STEP PERSONALIZATION:**\n\nFor each contact, follow this two-step approach:\n\nSTEP A — Extract hooks (analyze the profile):\nLook for 2-3 specific talking points from the contact's profile:\n- Shared alumni connection\n- Specific role, company, or career transition worth mentioning\n- Any mutual interests aligned with the user's background\n\nSTEP B — Generate the message:\nWrite a warm, professional outreach message using the hooks.\n\n**MESSAGE REQUIREMENTS:**\n- 80-120 words (LinkedIn message length)\n- Start with a specific observation (\"I noticed you...\" or \"Fellow [school] alum here...\")\n- Mention the shared connection or interest naturally\n- State the outreach goal clearly but softly (e.g. \"Open to a brief 15-min chat?\")\n- Professional but warm tone — NOT templated or AI-sounding\n- Do NOT mention job postings directly unless the goal is job-related\n- Do NOT use generic openers like \"I hope this finds you well\"\n- End with a low-pressure ask\n\n**STEP 1 — Load safe contacts:**\nCall load_data(filename=<the \"safe_contacts\" value from context>).\n\n**STEP 2 — Generate message for each contact:**\nFor each contact: generate the personalized message using the two-step approach above.\nAdd \"outreach_message\" field to each contact object.\n\n**STEP 3 — Write output and set:**\n- Call append_data(filename=\"personalized_contacts.jsonl\", data=<contact JSON with outreach_message>) for each.\n- Call set_output(\"personalized_contacts\", \"personalized_contacts.jsonl\")",
"tools": [
"load_data",
"append_data"
],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": false,
"success_criteria": null
},
{
"id": "send-outreach",
"name": "Send Outreach",
"description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review — emails are never sent automatically.",
"node_type": "event_loop",
"input_keys": [
"personalized_contacts",
"outreach_goal"
],
"output_keys": [
"drafts_created"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are an outreach execution assistant. Create Gmail draft emails for each contact.\n\n**CRITICAL RULE: NEVER send emails automatically. Only create drafts.**\n\n**STEP 1 — Load personalized contacts:**\nCall load_data(filename=<the \"personalized_contacts\" value from context>).\nProcess chunk by chunk if has_more=true.\n\n**STEP 2 — Create Gmail draft for each contact:**\nFor each contact with an \"outreach_message\":\n- subject: \"Coffee Chat Request\" (or appropriate subject based on outreach_goal)\n- to: contact's email address (use LinkedIn profile URL if email not available — note this in body)\n- body: the \"outreach_message\" from the contact object\n\nCall gmail_create_draft(\n to=<contact email or linkedin_url as placeholder>,\n subject=<appropriate subject line>,\n body=<outreach_message>\n)\n\nRecord each draft: call append_data(\n filename=\"drafts.jsonl\",\n data=<JSON: {contact_name, contact_email, subject, status: \"draft_created\"}>\n)\n\n**STEP 3 — Set output:**\n- Call set_output(\"drafts_created\", \"drafts.jsonl\")\n\n**IMPORTANT:** If a contact has no email address, create the draft with their LinkedIn URL as a placeholder and add a note in the body: \"Note: Please find the recipient's email before sending.\"",
"tools": [
"gmail_create_draft",
"load_data",
"append_data"
],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": false,
"success_criteria": null
},
{
"id": "report",
"name": "Report",
"description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
"node_type": "event_loop",
"input_keys": [
"drafts_created",
"filtered_count",
"outreach_goal"
],
"output_keys": [
"summary_report"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are an SDR assistant. Generate a clear campaign summary report and present it to the user.\n\n**STEP 1 — Load draft records:**\nCall load_data(filename=<the \"drafts_created\" value from context>) to read the draft records.\nIf has_more=true, load additional chunks until all records are loaded.\n\n**STEP 2 — Present the report (text only, NO tool calls):**\n\nPresent a clean summary:\n\n📊 **SDR Campaign Summary — [outreach_goal]**\n\n**Overview:**\n- Total contacts processed: [N]\n- Contacts filtered (suspicious profiles): [filtered_count]\n- Safe contacts messaged: [N - filtered_count]\n- Gmail drafts created: [N]\n\n**Drafts Created:**\nList each draft: Contact Name | Company | Subject\n\n**Next Steps:**\n\"Your Gmail drafts are ready for review. Please:\n1. Open Gmail and review each draft\n2. Personalize further if needed\n3. Send when ready\n\nCampaign complete!\"\n\n**STEP 3 — After the user responds, call set_output:**\n- set_output(\"summary_report\", <the formatted report text>)",
"tools": [
"load_data"
],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 0,
"output_model": null,
"max_validation_retries": 2,
"client_facing": true,
"success_criteria": null
},
{
"id": "complete",
"name": "Complete",
"description": "Terminal node - campaign complete.",
"node_type": "event_loop",
"input_keys": [
"summary_report"
],
"output_keys": [
"final_report"
],
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "Campaign is complete. Set the final output.\n\nCall set_output(\"final_report\", <summary_report value from context>)",
"tools": [],
"model": null,
"function": null,
"routes": {},
"max_retries": 3,
"retry_on": [],
"max_node_visits": 1,
"output_model": null,
"max_validation_retries": 2,
"client_facing": false,
"success_criteria": null
}
],
"edges": [
{
"id": "intake-to-score",
"source": "intake",
"target": "score-contacts",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
},
{
"id": "score-to-filter",
"source": "score-contacts",
"target": "filter-contacts",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
},
{
"id": "filter-to-personalize",
"source": "filter-contacts",
"target": "personalize",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
},
{
"id": "personalize-to-send",
"source": "personalize",
"target": "send-outreach",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
},
{
"id": "send-to-report",
"source": "send-outreach",
"target": "report",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
},
{
"id": "report-to-complete",
"source": "report",
"target": "complete",
"condition": "on_success",
"condition_expr": null,
"priority": 1,
"input_mapping": {}
}
],
"max_steps": 100,
"max_retries_per_node": 3,
"description": "Automated SDR outreach pipeline: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review."
},
"goal": {
"id": "sdr-agent",
"name": "SDR Agent",
"description": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
"status": "draft",
"success_criteria": [
{
"id": "contact-scoring-accuracy",
"description": "Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
"metric": "scoring_accuracy",
"target": ">=90%",
"weight": 0.30,
"met": false
},
{
"id": "scam-filter-effectiveness",
"description": "Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
"metric": "filter_precision",
"target": ">=95%",
"weight": 0.25,
"met": false
},
{
"id": "message-personalization",
"description": "Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
"metric": "personalization_score",
"target": ">=80%",
"weight": 0.30,
"met": false
},
{
"id": "draft-creation",
"description": "Gmail drafts are created for all safe contacts without errors",
"metric": "draft_success_rate",
"target": "100%",
"weight": 0.15,
"met": false
}
],
"constraints": [
{
"id": "draft-not-send",
"description": "Agent creates Gmail drafts but NEVER sends emails automatically",
"constraint_type": "hard",
"category": "safety",
"check": ""
},
{
"id": "respect-batch-limit",
"description": "Must not process more contacts than the configured max_contacts parameter",
"constraint_type": "hard",
"category": "operational",
"check": ""
},
{
"id": "skip-suspicious",
"description": "Contacts with risk_score >= 7 must be excluded from outreach",
"constraint_type": "hard",
"category": "safety",
"check": ""
}
],
"context": {},
"required_capabilities": [],
"input_schema": {},
"output_schema": {},
"version": "1.0.0",
"parent_version": null,
"evolution_reason": null
},
"required_tools": [
"gmail_create_draft",
"load_data",
"append_data"
],
"metadata": {
"node_count": 7,
"edge_count": 6
}
}
+375
View File
@@ -0,0 +1,375 @@
"""Agent graph construction for SDR Agent."""
from pathlib import Path
from framework.graph import EdgeSpec, EdgeCondition, Goal, SuccessCriterion, Constraint
from framework.graph.checkpoint_config import CheckpointConfig
from framework.graph.edge import AsyncEntryPointSpec, GraphSpec
from framework.graph.executor import ExecutionResult
from framework.llm import LiteLLMProvider
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.agent_runtime import AgentRuntime, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
from .config import default_config, metadata
from .nodes import (
intake_node,
score_contacts_node,
filter_contacts_node,
personalize_node,
send_outreach_node,
report_node,
)
# Goal definition
goal = Goal(
id="sdr-agent",
name="SDR Agent",
description=(
"Automate sales development outreach: score contacts by priority, "
"filter suspicious profiles, generate personalized messages, "
"and create Gmail drafts for human review."
),
success_criteria=[
SuccessCriterion(
id="contact-scoring-accuracy",
description=(
"Contacts are correctly scored and ranked by priority factors "
"(alumni status, connection degree, domain verification)"
),
metric="scoring_accuracy",
target=">=90%",
weight=0.30,
),
SuccessCriterion(
id="scam-filter-effectiveness",
description=(
"Suspicious profiles (risk_score >= 7) are correctly identified "
"and excluded from outreach"
),
metric="filter_precision",
target=">=95%",
weight=0.25,
),
SuccessCriterion(
id="message-personalization",
description=(
"Generated messages reference specific profile details "
"(alumni connection, role, company) and match the outreach goal"
),
metric="personalization_score",
target=">=80%",
weight=0.30,
),
SuccessCriterion(
id="draft-creation",
description="Gmail drafts are created for all safe contacts without errors",
metric="draft_success_rate",
target="100%",
weight=0.15,
),
],
constraints=[
Constraint(
id="draft-not-send",
description="Agent creates Gmail drafts but NEVER sends emails automatically",
constraint_type="hard",
category="safety",
),
Constraint(
id="respect-batch-limit",
description="Must not process more contacts than the configured max_contacts parameter",
constraint_type="hard",
category="operational",
),
Constraint(
id="skip-suspicious",
description="Contacts with risk_score >= 7 must be excluded from outreach",
constraint_type="hard",
category="safety",
),
],
)
# Node list
nodes = [
intake_node,
score_contacts_node,
filter_contacts_node,
personalize_node,
send_outreach_node,
report_node,
]
# Edge definitions
edges = [
EdgeSpec(
id="intake-to-score",
source="intake",
target="score-contacts",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
EdgeSpec(
id="score-to-filter",
source="score-contacts",
target="filter-contacts",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
EdgeSpec(
id="filter-to-personalize",
source="filter-contacts",
target="personalize",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
EdgeSpec(
id="personalize-to-send",
source="personalize",
target="send-outreach",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
EdgeSpec(
id="send-to-report",
source="send-outreach",
target="report",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
EdgeSpec(
id="report-to-intake",
source="report",
target="intake",
condition=EdgeCondition.ON_SUCCESS,
priority=1,
),
]
# Graph configuration
entry_node = "intake"
entry_points = {"start": "intake"}
async_entry_points: list[AsyncEntryPointSpec] = [] # SDR Agent is manually triggered
pause_nodes = []
terminal_nodes = []
loop_config = {
"max_iterations": 100,
"max_tool_calls_per_turn": 30,
"max_tool_result_chars": 8000,
"max_history_tokens": 32000,
}
conversation_mode = "continuous"
identity_prompt = (
"You are an SDR (Sales Development Representative) assistant. "
"You help users automate their outreach by scoring contacts, filtering "
"suspicious profiles, generating personalized messages, and creating "
"Gmail drafts — all with human review before anything is sent."
)
class SDRAgent:
"""
SDR Agent — 6-node pipeline for automated outreach.
Flow: intake -> score-contacts -> filter-contacts -> personalize
-> send-outreach -> report -> intake (loop)
Pipeline:
1. intake: Receive contact list and outreach goal
2. score-contacts: Rank contacts 0-100 by priority factors
3. filter-contacts: Remove suspicious profiles (risk >= 7)
4. personalize: Generate personalized messages for each contact
5. send-outreach: Create Gmail drafts (never sends automatically)
6. report: Summarize campaign results and present to user
"""
def __init__(self, config=None):
self.config = config or default_config
self.goal = goal
self.nodes = nodes
self.edges = edges
self.entry_node = entry_node
self.entry_points = entry_points
self.pause_nodes = pause_nodes
self.terminal_nodes = terminal_nodes
self._agent_runtime: AgentRuntime | None = None
self._graph: GraphSpec | None = None
self._tool_registry: ToolRegistry | None = None
def _build_graph(self) -> GraphSpec:
"""Build the GraphSpec."""
return GraphSpec(
id="sdr-agent-graph",
goal_id=self.goal.id,
version="1.0.0",
entry_node=self.entry_node,
entry_points=self.entry_points,
terminal_nodes=self.terminal_nodes,
pause_nodes=self.pause_nodes,
nodes=self.nodes,
edges=self.edges,
default_model=self.config.model,
max_tokens=self.config.max_tokens,
loop_config=loop_config,
conversation_mode=conversation_mode,
identity_prompt=identity_prompt,
)
def _setup(self, mock_mode=False) -> None:
"""Set up the agent runtime with sessions, checkpoints, and logging."""
self._storage_path = Path.home() / ".hive" / "agents" / "sdr_agent"
self._storage_path.mkdir(parents=True, exist_ok=True)
self._tool_registry = ToolRegistry()
mcp_config_path = Path(__file__).parent / "mcp_servers.json"
if mcp_config_path.exists():
self._tool_registry.load_mcp_config(mcp_config_path)
tools_path = Path(__file__).parent / "tools.py"
if tools_path.exists():
self._tool_registry.discover_from_module(tools_path)
if mock_mode:
from framework.llm.mock import MockLLMProvider
llm = MockLLMProvider()
else:
llm = LiteLLMProvider(
model=self.config.model,
api_key=self.config.api_key,
api_base=self.config.api_base,
)
tool_executor = self._tool_registry.get_executor()
tools = list(self._tool_registry.get_tools().values())
self._graph = self._build_graph()
checkpoint_config = CheckpointConfig(
enabled=True,
checkpoint_on_node_start=False,
checkpoint_on_node_complete=True,
checkpoint_max_age_days=7,
async_checkpoint=True,
)
entry_point_specs = [
EntryPointSpec(
id="default",
name="Default",
entry_node=self.entry_node,
trigger_type="manual",
isolation_level="shared",
),
]
self._agent_runtime = create_agent_runtime(
graph=self._graph,
goal=self.goal,
storage_path=self._storage_path,
entry_points=entry_point_specs,
llm=llm,
tools=tools,
tool_executor=tool_executor,
checkpoint_config=checkpoint_config,
)
async def start(self, mock_mode=False) -> None:
"""Set up and start the agent runtime."""
if self._agent_runtime is None:
self._setup(mock_mode=mock_mode)
if not self._agent_runtime.is_running:
await self._agent_runtime.start()
async def stop(self) -> None:
"""Stop the agent runtime and clean up."""
if self._agent_runtime and self._agent_runtime.is_running:
await self._agent_runtime.stop()
self._agent_runtime = None
async def trigger_and_wait(
self,
entry_point: str,
input_data: dict,
timeout: float | None = None,
session_state: dict | None = None,
) -> ExecutionResult | None:
"""Execute the graph and wait for completion."""
if self._agent_runtime is None:
raise RuntimeError("Agent not started. Call start() first.")
return await self._agent_runtime.trigger_and_wait(
entry_point_id=entry_point,
input_data=input_data,
timeout=timeout,
session_state=session_state,
)
async def run(
self, context: dict, mock_mode=False, session_state=None
) -> ExecutionResult:
"""Run the agent (convenience method for single execution)."""
await self.start(mock_mode=mock_mode)
try:
result = await self.trigger_and_wait(
"default", context, session_state=session_state
)
return result or ExecutionResult(success=False, error="Execution timeout")
finally:
await self.stop()
def info(self):
"""Get agent information."""
return {
"name": metadata.name,
"version": metadata.version,
"description": metadata.description,
"goal": {
"name": self.goal.name,
"description": self.goal.description,
},
"nodes": [n.id for n in self.nodes],
"edges": [e.id for e in self.edges],
"entry_node": self.entry_node,
"entry_points": self.entry_points,
"pause_nodes": self.pause_nodes,
"terminal_nodes": self.terminal_nodes,
"client_facing_nodes": [n.id for n in self.nodes if n.client_facing],
}
def validate(self):
"""Validate agent structure."""
errors = []
warnings = []
node_ids = {node.id for node in self.nodes}
for edge in self.edges:
if edge.source not in node_ids:
errors.append(f"Edge {edge.id}: source '{edge.source}' not found")
if edge.target not in node_ids:
errors.append(f"Edge {edge.id}: target '{edge.target}' not found")
if self.entry_node not in node_ids:
errors.append(f"Entry node '{self.entry_node}' not found")
for terminal in self.terminal_nodes:
if terminal not in node_ids:
errors.append(f"Terminal node '{terminal}' not found")
for ep_id, node_id in self.entry_points.items():
if node_id not in node_ids:
errors.append(
f"Entry point '{ep_id}' references unknown node '{node_id}'"
)
return {
"valid": len(errors) == 0,
"errors": errors,
"warnings": warnings,
}
# Create default instance
default_agent = SDRAgent()
+30
View File
@@ -0,0 +1,30 @@
"""Runtime configuration for SDR Agent."""
from dataclasses import dataclass
from framework.config import RuntimeConfig
default_config = RuntimeConfig()
@dataclass
class AgentMetadata:
name: str = "SDR Agent"
version: str = "1.0.0"
description: str = (
"Automate sales development outreach using AI-powered contact scoring, "
"scam detection, and personalized message generation. "
"Score contacts by priority, filter suspicious profiles, generate "
"personalized outreach messages, and create Gmail drafts for review."
)
intro_message: str = (
"Hi! I'm your SDR (Sales Development Representative) assistant. "
"Provide a list of contacts and your outreach goal, and I'll "
"score them by priority, filter out suspicious profiles, generate "
"personalized messages for each contact, and create Gmail drafts "
"for your review. I never send emails automatically — you stay in control. "
"To get started, share your contact list and tell me about your outreach goal!"
)
metadata = AgentMetadata()
+97
View File
@@ -0,0 +1,97 @@
[
  {
    "name": "Sarah Chen",
    "email": "sarah.chen@techcorp.io",
    "company": "TechCorp",
    "title": "Learning & Development Manager",
    "linkedin_url": "https://linkedin.com/in/sarah-chen-ld",
    "connection_degree": "2nd",
    "is_alumni": true,
    "school_name": "University of Western Ontario",
    "connections_count": 843,
    "mutual_connections": 7,
    "has_photo": true,
    "company_domain_verified": true
  },
  {
    "name": "James Okafor",
    "email": "james.okafor@edventure.co",
    "company": "EdVenture",
    "title": "Instructional Designer",
    "linkedin_url": "https://linkedin.com/in/james-okafor-id",
    "connection_degree": "1st",
    "is_alumni": false,
    "connections_count": 621,
    "mutual_connections": 12,
    "has_photo": true,
    "company_domain_verified": true
  },
  {
    "name": "Emily Zhao",
    "email": "emily.zhao@univedu.ca",
    "company": "UniEdu",
    "title": "Director of Digital Learning",
    "linkedin_url": "https://linkedin.com/in/emily-zhao-dl",
    "connection_degree": "2nd",
    "is_alumni": true,
    "school_name": "University of Western Ontario",
    "connections_count": 1204,
    "mutual_connections": 3,
    "has_photo": true,
    "company_domain_verified": true,
    "active_job_posting": true
  },
  {
    "name": "Marcus Williams",
    "email": "marcus@growthsales.io",
    "company": "GrowthSales",
    "title": "CEO",
    "linkedin_url": "https://linkedin.com/in/marcus-williams-ceo",
    "connection_degree": "3rd",
    "is_alumni": false,
    "connections_count": 6300,
    "mutual_connections": 0,
    "has_photo": true,
    "company_domain_verified": false
  },
  {
    "name": "Priya Patel",
    "email": "",
    "company": "FutureLearn Inc.",
    "title": "EdTech Product Manager",
    "linkedin_url": "https://linkedin.com/in/priya-patel-edtech",
    "connection_degree": "2nd",
    "is_alumni": false,
    "connections_count": 512,
    "mutual_connections": 5,
    "has_photo": true,
    "company_domain_verified": true
  },
  {
    "name": "Alex Johnson",
    "email": "alex@bizopp.biz",
    "company": "Biz Opportunity Global",
    "title": "Entrepreneur",
    "linkedin_url": "https://linkedin.com/in/alex-johnson-biz",
    "connection_degree": "3rd",
    "is_alumni": false,
    "connections_count": 38,
    "mutual_connections": 0,
    "has_photo": false,
    "company_domain_verified": false
  },
  {
    "name": "Natalie Brown",
    "email": "natalie.brown@learningpro.com",
    "company": "LearningPro",
    "title": "HR Learning Specialist",
    "linkedin_url": "https://linkedin.com/in/natalie-brown-hr",
    "connection_degree": "1st",
    "is_alumni": true,
    "school_name": "University of Western Ontario",
    "connections_count": 389,
    "mutual_connections": 9,
    "has_photo": true,
    "company_domain_verified": true
  }
]
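The fields in this sample (is_alumni, connection_degree, company_domain_verified, mutual_connections, active_job_posting, has_photo, connections_count) are exactly the priority and risk signals the graph spec below refers to. A minimal sketch of one scoring/filter pass over these records follows; the point values and risk heuristics are assumptions chosen to match the factors the spec names, and the file name `contacts.json` is hypothetical. Only the "risk_score >= 7 is excluded" rule comes from the spec's constraints.

```python
import json

# Assumed weights, picked to reflect the factors the graph spec names:
# connection degree, alumni status, domain verification, mutual
# connections, and active job postings.
DEGREE_POINTS = {"1st": 30, "2nd": 20, "3rd": 5}

def priority_score(c: dict) -> int:
    score = DEGREE_POINTS.get(c.get("connection_degree"), 0)
    score += 25 if c.get("is_alumni") else 0
    score += 15 if c.get("company_domain_verified") else 0
    score += min(c.get("mutual_connections", 0) * 2, 20)
    score += 10 if c.get("active_job_posting") else 0
    return min(score, 100)

def risk_score(c: dict) -> int:
    # Heuristic authenticity signals; thresholds are illustrative.
    risk = 0
    if not c.get("company_domain_verified"):
        risk += 3
    if not c.get("has_photo"):
        risk += 2
    if c.get("connections_count", 0) < 100:
        risk += 2
    if not c.get("email"):
        risk += 2
    if c.get("mutual_connections", 0) == 0:
        risk += 1
    return risk

with open("contacts.json") as f:  # hypothetical path to the sample above
    contacts = json.load(f)

for c in contacts:
    c["priority_score"] = priority_score(c)
    c["risk_score"] = risk_score(c)

# Constraint from the graph spec: risk_score >= 7 is excluded from outreach.
safe = sorted(
    (c for c in contacts if c["risk_score"] < 7),
    key=lambda c: c["priority_score"],
    reverse=True,
)
for c in safe:
    print(f'{c["priority_score"]:3d}  {c["name"]} ({c["company"]})')
```

Under these assumed weights the thin "Alex Johnson" profile (unverified domain, no photo, 38 connections, no mutuals) scores risk 8 and is filtered, which is the behavior the sample data appears designed to exercise.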
@@ -0,0 +1,270 @@
{
  "original_draft": {
    "agent_name": "sdr_agent",
    "goal": "Automate sales development outreach: score contacts by priority, filter suspicious profiles, generate personalized messages, and create Gmail drafts for human review.",
    "description": "",
    "success_criteria": [
      "Contacts are correctly scored and ranked by priority factors (alumni status, connection degree, domain verification)",
      "Suspicious profiles (risk_score >= 7) are correctly identified and excluded from outreach",
      "Generated messages reference specific profile details (alumni connection, role, company) and match the outreach goal",
      "Gmail drafts are created for all safe contacts without errors"
    ],
    "constraints": [
      "Agent creates Gmail drafts but NEVER sends emails automatically",
      "Must not process more contacts than the configured max_contacts parameter",
      "Contacts with risk_score >= 7 must be excluded from outreach"
    ],
    "nodes": [
      {
        "id": "intake",
        "name": "Intake",
        "description": "Receive the contact list and outreach goal from the user. Confirm the strategy and batch size before proceeding.",
        "node_type": "event_loop",
        "tools": [
          "load_contacts_from_file"
        ],
        "input_keys": [
          "contacts",
          "outreach_goal",
          "max_contacts",
          "user_background"
        ],
        "output_keys": [
          "contacts",
          "outreach_goal",
          "max_contacts",
          "user_background"
        ],
        "success_criteria": "The user has confirmed the contact list, outreach goal, batch size, and their background. All four keys have been written via set_output.",
        "sub_agents": [],
        "flowchart_type": "start",
        "flowchart_shape": "stadium",
        "flowchart_color": "#8aad3f"
      },
      {
        "id": "score-contacts",
        "name": "Score Contacts",
        "description": "Score and rank each contact from 0 to 100 based on priority factors: alumni status, connection degree, domain verification, mutual connections, and active job postings.",
        "node_type": "event_loop",
        "tools": [
          "load_data",
          "append_data"
        ],
        "input_keys": [
          "contacts",
          "outreach_goal"
        ],
        "output_keys": [
          "scored_contacts"
        ],
        "success_criteria": "Every contact has a priority_score field (0-100) and scored_contacts.jsonl has been written and referenced via set_output.",
        "sub_agents": [],
        "flowchart_type": "database",
        "flowchart_shape": "cylinder",
        "flowchart_color": "#508878"
      },
      {
        "id": "filter-contacts",
        "name": "Filter Contacts",
        "description": "Analyze each contact for authenticity and filter out suspicious profiles. Any contact with a risk score of 7 or higher is skipped.",
        "node_type": "event_loop",
        "tools": [
          "load_data",
          "append_data"
        ],
        "input_keys": [
          "scored_contacts"
        ],
        "output_keys": [
          "safe_contacts",
          "filtered_count"
        ],
        "success_criteria": "Each contact has a risk_score and recommendation field. Contacts with risk_score >= 7 are excluded. safe_contacts.jsonl and filtered_count are set via set_output.",
        "sub_agents": [],
        "flowchart_type": "database",
        "flowchart_shape": "cylinder",
        "flowchart_color": "#508878"
      },
      {
        "id": "personalize",
        "name": "Personalize",
        "description": "Generate a personalized outreach message for each contact based on their profile, shared background, and the user's outreach goal.",
        "node_type": "event_loop",
        "tools": [
          "load_data",
          "append_data"
        ],
        "input_keys": [
          "safe_contacts",
          "outreach_goal",
          "user_background"
        ],
        "output_keys": [
          "personalized_contacts"
        ],
        "success_criteria": "Every safe contact has an outreach_message field of 80-120 words that references a specific hook from their profile. personalized_contacts.jsonl is set via set_output.",
        "sub_agents": [],
        "flowchart_type": "database",
        "flowchart_shape": "cylinder",
        "flowchart_color": "#508878"
      },
      {
        "id": "send-outreach",
        "name": "Send Outreach",
        "description": "Create Gmail draft emails for each contact using their personalized message. Drafts are created for human review \u2014 emails are never sent automatically.",
        "node_type": "event_loop",
        "tools": [
          "gmail_create_draft",
          "load_data",
          "append_data"
        ],
        "input_keys": [
          "personalized_contacts",
          "outreach_goal"
        ],
        "output_keys": [
          "drafts_created"
        ],
        "success_criteria": "A Gmail draft has been created for every safe contact. drafts.jsonl records each draft and drafts_created is set via set_output.",
        "sub_agents": [],
        "flowchart_type": "database",
        "flowchart_shape": "cylinder",
        "flowchart_color": "#508878"
      },
      {
        "id": "report",
        "name": "Report",
        "description": "Generate a summary report of the outreach campaign: contacts scored, filtered, messaged, and drafts created. Present to user for review.",
        "node_type": "event_loop",
        "tools": [
          "load_data"
        ],
        "input_keys": [
          "drafts_created",
          "filtered_count",
          "outreach_goal"
        ],
        "output_keys": [
          "summary_report"
        ],
        "success_criteria": "A campaign summary has been presented to the user listing totals for contacts scored, filtered, messaged, and drafts created. summary_report is set via set_output.",
        "sub_agents": [],
        "flowchart_type": "terminal",
        "flowchart_shape": "stadium",
        "flowchart_color": "#b5453a"
      }
    ],
    "edges": [
      {
        "id": "edge-0",
        "source": "intake",
        "target": "score-contacts",
        "condition": "on_success",
        "description": "",
        "label": ""
      },
      {
        "id": "edge-1",
        "source": "score-contacts",
        "target": "filter-contacts",
        "condition": "on_success",
        "description": "",
        "label": ""
      },
      {
        "id": "edge-2",
        "source": "filter-contacts",
        "target": "personalize",
        "condition": "on_success",
        "description": "",
        "label": ""
      },
      {
        "id": "edge-3",
        "source": "personalize",
        "target": "send-outreach",
        "condition": "on_success",
        "description": "",
        "label": ""
      },
      {
        "id": "edge-4",
        "source": "send-outreach",
        "target": "report",
        "condition": "on_success",
        "description": "",
        "label": ""
      },
      {
        "id": "edge-5",
        "source": "report",
        "target": "intake",
        "condition": "on_success",
        "description": "",
        "label": ""
      }
    ],
    "entry_node": "intake",
    "terminal_nodes": [
      "report"
    ],
    "flowchart_legend": {
      "start": {
        "shape": "stadium",
        "color": "#8aad3f"
      },
      "terminal": {
        "shape": "stadium",
        "color": "#b5453a"
      },
      "process": {
        "shape": "rectangle",
        "color": "#b5a575"
      },
      "decision": {
        "shape": "diamond",
        "color": "#d89d26"
      },
      "io": {
        "shape": "parallelogram",
        "color": "#d06818"
      },
      "document": {
        "shape": "document",
        "color": "#c4b830"
      },
      "database": {
        "shape": "cylinder",
        "color": "#508878"
      },
      "subprocess": {
        "shape": "subroutine",
        "color": "#887a48"
      },
      "browser": {
        "shape": "hexagon",
        "color": "#cc8850"
      }
    }
  },
  "flowchart_map": {
    "intake": [
      "intake"
    ],
    "score-contacts": [
      "score-contacts"
    ],
    "filter-contacts": [
      "filter-contacts"
    ],
    "personalize": [
      "personalize"
    ],
    "send-outreach": [
      "send-outreach"
    ],
    "report": [
      "report"
    ]
  }
}
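Note how the nodes hand data to each other: each stage writes a .jsonl file (scored_contacts.jsonl, safe_contacts.jsonl, personalized_contacts.jsonl, drafts.jsonl) and publishes it through set_output under the matching output key, which the next node lists among its input_keys; edge-5 (report back to intake) also closes the graph into a loop so a new campaign can start after review. The real load_data/append_data tool signatures are not shown in this diff, so the helpers below are plain-Python stand-ins that capture the pattern only:

```python
import json

def append_data(path: str, record: dict) -> None:
    # One JSON object per line, appended so a node can stream results
    # as it processes each contact.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def load_data(path: str) -> list[dict]:
    # Read a whole stage's output back in for the next node.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# e.g. the Score Contacts node appends each scored record...
append_data("scored_contacts.jsonl", {"name": "Sarah Chen", "priority_score": 85})
# ...and the Filter Contacts node reads them all back.
print(load_data("scored_contacts.jsonl"))
```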
@@ -0,0 +1,14 @@
{
  "hive-tools": {
    "transport": "stdio",
    "command": "uv",
    "args": [
      "run",
      "python",
      "mcp_server.py",
      "--stdio"
    ],
    "cwd": "../../../tools",
    "description": "Hive tools MCP server"
  }
}
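This config tells an MCP host to spawn the tools server as a subprocess (`uv run python mcp_server.py --stdio` in `../../../tools`) and speak the protocol over stdin/stdout. The framework presumably wires this up itself; as a sketch only, the same connection can be made directly with the official `mcp` Python SDK:

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def main() -> None:
    # Mirrors the "hive-tools" entry above.
    params = StdioServerParameters(
        command="uv",
        args=["run", "python", "mcp_server.py", "--stdio"],
        cwd="../../../tools",  # relative path taken from the config
    )
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([t.name for t in tools.tools])

asyncio.run(main())
```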

Some files were not shown because too many files have changed in this diff.