Compare commits


78 Commits

Author SHA1 Message Date
Timothy f748187391 fix: posix only permission check for skill 2026-03-31 19:06:23 -07:00
Timothy eafbeb78b4 fix: python test 2026-03-31 18:55:24 -07:00
Timothy 5cb5083f8d fix(micro-fix): queen skill allowlist 2026-03-31 18:52:45 -07:00
Bryan @ Aden bf86daee92 Merge pull request #6319 from KartikPawade/fix/sap-tool-credential-store
fix: use CredentialStoreAdapter in sap_tool instead of raw os.getenv()
2026-03-31 18:30:21 -07:00
Timothy 43bbd0f31f feat(micro-fix): skill cli parser 2026-03-31 18:13:01 -07:00
Timothy @aden 2cf962b538 Merge pull request #6782 from levxn/skills/cli-commands
feat(skills): implement hive skill CLI subcommands (CLI-1 through CLI-13)
2026-03-31 17:59:25 -07:00
Timothy 4298196700 Merge branch 'main' into feature/agent-skills 2026-03-31 17:53:57 -07:00
Timothy @aden bc1f712e42 Merge pull request #6610 from levxn/skills/ds-ovrride-heuristics
feat(skills): DS-12 and DS-13 — config override application, batch auto-detection, and context preservation warning
2026-03-31 17:51:19 -07:00
Timothy @aden cccbcc8ec3 Merge pull request #6529 from vakrahul/fix/mcp-structured-errors
feat: structured MCP error codes and failure diagnostics (closes #6352)
2026-03-31 17:40:50 -07:00
Timothy @aden 0722f83f16 Merge pull request #6792 from fermano/feat/agent-selection-tool-resolution-n-framework-integration
Feat/agent selection tool resolution n framework integration
2026-03-31 17:38:39 -07:00
Hundao 72091d2783 fix(security): add SSRF protection to web_scrape tool (#6879)
Validate URLs against internal network ranges before making requests.
Block private IPs, loopback, link-local, and cloud metadata endpoints
(169.254.169.254). Intercept Playwright navigation to catch redirect-based
SSRF bypasses.

Fixes #1157

Co-authored-by: Harshit <Harshitk-cp@users.noreply.github.com>
2026-03-31 14:04:47 +08:00
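The SSRF guard described in the commit above — validate the target against private, loopback, link-local, and cloud-metadata ranges before requesting — can be sketched as below. Function name and structure are illustrative assumptions; the actual `web_scrape` guard (and its Playwright redirect interception) may differ.

```python
import ipaddress
import socket
from urllib.parse import urlparse


def is_url_blocked(url: str) -> bool:
    """Return True if the URL resolves to an internal or metadata address.

    Sketch only; the real web_scrape validation may check more ranges.
    """
    host = urlparse(url).hostname
    if host is None:
        return True
    try:
        # Resolve the hostname so DNS-based tricks can't hide a private IP.
        infos = socket.getaddrinfo(host, None)
    except socket.gaierror:
        return True
    for info in infos:
        ip = ipaddress.ip_address(info[4][0])
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip == ipaddress.ip_address("169.254.169.254")  # cloud metadata
        ):
            return True
    return False
```

Note that this only validates the initial URL; the commit also intercepts Playwright navigation because a public URL can redirect to a private one after the first check passes.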
Kartik 3bb69a5784 fix: add env fallback and type hints for SAP tool credentials
Made-with: Cursor
2026-03-31 11:15:41 +05:30
Kartik 63fb089062 chore: format sap_tool.py
Made-with: Cursor
2026-03-31 10:21:56 +05:30
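The two SAP-tool commits above route credentials through a store adapter with an environment-variable fallback instead of raw `os.getenv()`. A minimal sketch of that pattern follows; the `CredentialStoreAdapter` API here is hypothetical and stands in for the framework's real adapter.

```python
import os
from typing import Optional


class CredentialStoreAdapter:
    """Minimal stand-in for the framework adapter; the real API may differ."""

    def __init__(self, store: Optional[dict] = None) -> None:
        self._store = store or {}

    def get(self, key: str) -> Optional[str]:
        # Prefer the managed credential store, then fall back to the
        # environment so existing os.getenv-based deployments keep working.
        value = self._store.get(key)
        if value is not None:
            return value
        return os.environ.get(key)


def get_sap_credential(adapter: CredentialStoreAdapter, key: str) -> str:
    value = adapter.get(key)
    if value is None:
        raise KeyError(f"missing SAP credential: {key}")
    return value
```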
Hundao d5ba985e29 docs: fix agent.json examples to match current schema (#6878)
Replace outdated node_id/edge_id with id, wrap nodes/edges under
graph key, add goal section with success_criteria. Matches what
load_agent_export() and NodeSpec actually expect.

Fixes #897

Co-authored-by: Jose37456 <Jose37456@users.noreply.github.com>
2026-03-31 12:41:59 +08:00
Bryan @ Aden 6ee510d2f6 Merge pull request #6855 from Ttian18/feat/tina/docs-mcp-unix-sse-transport
docs: add Unix socket and SSE transport to MCP Integration Guide (#6739)
2026-03-30 18:46:01 -07:00
Bryan @ Aden 45b350e7c8 Merge pull request #6857 from Ttian18/feat/tina/job-hunter-pdf-resume
feat(job-hunter): support PDF resume input via file path (#6740)
2026-03-30 18:45:36 -07:00
Bryan @ Aden 7e690de12f Merge pull request #6844 from sundaram2021/fix/quickstart-credentials-in-windows
micro-fix:  shell config handling and add antigravity option
2026-03-30 17:20:36 -07:00
Hundao ae85d2bf59 fix(security): prevent path traversal in session_store (#6876)
Validate that resolved session path stays within the sessions directory
using Path.is_relative_to(). Prevents session_id values like
"../../something" from escaping the sandbox.

Also guard the caller in _write_run_event where get_session_path is
called outside the existing OSError try/except block.

Fixes #1000

Co-authored-by: Sidhartha kumar <Alearner12@users.noreply.github.com>
2026-03-30 23:53:23 +08:00
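The `Path.is_relative_to()` check described in this commit can be sketched as follows. Names are hypothetical; the real `session_store` helper may be shaped differently.

```python
from pathlib import Path


def get_session_path(sessions_dir: Path, session_id: str) -> Path:
    """Resolve a session path and refuse anything outside sessions_dir.

    Resolving before the containment check is what defeats session_id
    values like "../../something".
    """
    sessions_dir = sessions_dir.resolve()
    candidate = (sessions_dir / session_id).resolve()
    if not candidate.is_relative_to(sessions_dir):
        raise ValueError(f"session_id escapes sandbox: {session_id!r}")
    return candidate
```

`Path.is_relative_to()` requires Python 3.9+; on older versions the equivalent is catching `ValueError` from `candidate.relative_to(sessions_dir)`.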
Juttiga Bheemeswar e9fd0158b9 fix(csv_sql): prevent SQL injection via DuckDB parameter binding (#1408)
* fix(csv_sql): prevent SQL injection via DuckDB parameter binding

* test(csv_sql): add regression test for apostrophe path

* Refactor CSV query function for security and clarity

Removed detailed docstring arguments and return information for the CSV query function. Improved security checks for SQL queries.

* fix/1256-csv-sql-safe-path

Added security regression tests to reject non-SELECT queries and multi-statement queries.

* docs: restore csv_sql docstring (Args, Returns, Examples)

* fix: use word-boundary regex for SQL keyword detection

Substring matching caused false positives on column names like
created_at, updated_at, deleted_at. Switch to \b word-boundary regex.
Also add tests for comment rejection, CTE queries, and keyword-in-column-name.

---------

Co-authored-by: Juttiga Bheem <BBemail@gmail.com>
Co-authored-by: hundao <alchemy_wimp@hotmail.com>
2026-03-30 23:26:07 +08:00
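The word-boundary fix in the last bullet above (substring matching flagged column names like `created_at`) can be sketched like this. The keyword list and function name are illustrative assumptions, not the actual `csv_sql` implementation; actual value interpolation is handled separately via DuckDB parameter binding.

```python
import re

# Keywords that should never appear in a read-only CSV query.
# Illustrative list only; the real csv_sql guard may check more.
_FORBIDDEN = ("insert", "update", "delete", "drop", "alter", "attach", "copy")
_FORBIDDEN_RE = re.compile(r"\b(" + "|".join(_FORBIDDEN) + r")\b", re.IGNORECASE)


def is_query_allowed(sql: str) -> bool:
    """Allow single-statement SELECT/WITH queries without write keywords."""
    stripped = sql.strip().rstrip(";")
    if ";" in stripped:  # reject multi-statement input
        return False
    if not re.match(r"(?i)\s*(select|with)\b", stripped):
        return False
    # \b word boundaries avoid false positives on created_at, deleted_at, ...
    return _FORBIDDEN_RE.search(stripped) is None
```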
Zhang 9a68a5d7ee fix(job-hunter): align intake node client_facing and input_keys with agent.json 2026-03-29 11:33:57 -07:00
Zhang 33edf4a207 feat(job-hunter): support PDF resume input via file path (#6740) 2026-03-29 11:26:35 -07:00
Zhang f9fdaf5adc docs: clarify required vs optional fields for Unix and SSE transports 2026-03-29 10:29:47 -07:00
Zhang eabb17934c docs: add Unix socket and SSE transport types to MCP Integration Guide 2026-03-29 10:24:57 -07:00
kernel_crush eba7524955 refactor: remove deprecated storage/backend.py (267 lines) (#6849)
* refactor: remove deprecated storage/backend.py (267 lines)

Delete the fully deprecated FileStorage class and inline its 5 still-active
methods (_validate_key, _load_run_sync, _load_summary_sync, _delete_run_sync,
_list_all_runs_sync) directly into ConcurrentStorage.

Changes:
- Delete core/framework/storage/backend.py (267 lines of no-op/deprecated code)
- Inline active read methods into ConcurrentStorage (no new FileStorage dep)
- Remove deprecated index operations (get_runs_by_goal, get_runs_by_status,
  get_runs_by_node, list_all_goals) and their associated locking
- Update __init__.py to export ConcurrentStorage instead of FileStorage
- Update runtime/core.py to use ConcurrentStorage directly
- Fix Runtime.end_run() to call save_run_sync() (sync wrapper) instead of
  the async save_run(), which was silently dropping the coroutine
- Update test_path_traversal_fix.py to test ConcurrentStorage._validate_key()
- Clean up test_storage.py — remove all FileStorage test classes, un-skip
  ConcurrentStorage tests now that it's self-contained
- Remove stale FileStorage references from testing/test_storage.py docstring,
  testing/debug_tool.py docstring, and test_runtime.py skip reasons

All 44 tests pass, ruff check and ruff format clean.

Fixes #6797

* fix(core): address CodeRabbitAI PR review feedback

 - Fix critical no-op in ConcurrentStorage._save_run_sync by implementing atomic persistence to runs/{run_id}.json.
 - Update test_path_traversal_fix.py to test ConcurrentStorage directly and use real file paths for end-to-end validation.
 - Unskip test_run_saved_on_end and assert actual run file persistence.
 - Fix debug_tool.py to use load_run_sync() instead of the async load_run().

* fix(core): address round 2 of CodeRabbitAI reviews

 - Add _validate_key to _save_run_sync and _load_summary_sync to enforce path traversal protections on the lowest level APIs.
 - Invalidate summary cache and refresh run cache in save_run_sync() to match the async save_run() cache coherence behavior.
 - Add tests for load_summary and save_run_sync path traversal rejection.
2026-03-29 22:48:12 +08:00
Sundaram Kumar Jha c56440340a Merge origin/main into fix/quickstart-credentials-in-windows 2026-03-29 08:44:26 +05:30
Bhuvaneswari N c889ffd85d feat(scripts): add support for more LLM providers in check_llm_key.py (#6833)
* feat(scripts): add support for more LLM providers in check_llm_key.py

* fix(scripts): correct perplexity endpoint to /v1/models and simplify lambda kwargs to **_
2026-03-29 09:11:25 +08:00
Md. Afzal Hassan Ehsani 905a4f3516 feat(quickstart): add Local (Ollama) LLM provider option (#6028)
* feat(quickstart): add Local (Ollama) LLM provider option
- Detect Ollama via 'ollama list' in quickstart.sh and quickstart.ps1
- Add 'Local (Ollama)' menu option with interactive model picker
- Save provider=ollama, model=<selected> to ~/.hive/configuration.json
- Omit api_key_env_var for Ollama (no API key required)
Refs #5154, #5231

* feat: add local Ollama support and resolve native tool calling

This integrates Ollama as a first-class local provider choice during quickstart, and patches several configuration barriers preventing local models from safely executing the framework's agent graphs.

* **Quickstart Integration**: Added `Local (Ollama)` to the provider menu in both quickstart.sh and quickstart.ps1. When selected, it automatically queries `ollama list` and allows the user to pick an installed model without prompting for an API key.
* **Routing & Configuration**: Automatically sets `"api_base": "http://localhost:11434"` so LiteLLM routes correctly to the local daemon, and increases the default max_tokens config.py allocation to `32768`.
* **Native Tool Calling**: Normalized Ollama models to strictly use the ollama_chat provider prefix inside litellm.py and registered them as `supports_function_calling: True`. This forces native structured function calling and fixes the infinite loop caused by JSON-mode text fallbacks.
* **Context Truncation Fix**: Updated config.py to explicitly pass `"num_ctx": 16384` to Ollama. This prevents the local daemon from silently truncating the Queen agent's ~9,500 token system prompt (Ollama defaults to 2048 `num_ctx`).
* **UX Warnings**: Added terminal notices warning users to select high-parameter models (e.g., `qwen2.5:72b+`) to ensure sufficient contextual reasoning abilities.

Resolves #6027
Resolves #6028

* test: add unit tests for Ollama helper functions

Cover _is_ollama_model(), _ensure_ollama_chat_prefix(), and num_ctx
injection in get_llm_extra_kwargs() as requested in PR review.
Fix existing test_init_ollama_no_key_needed assertion to expect the
normalised ollama_chat/ prefix.

Made-with: Cursor

* chores: fixed merge conflict

* fix(ollama): address PR review comments and normalize provider config

* fix(ollama): align quickstart defaults and add tool_choice comment

* fix(ollama): enforce OLLAMA_DETECTED logic and resolve quickstart script syntax errors

* fix(ollama): align quickstart logic and cleanup test imports
2026-03-29 08:51:47 +08:00
Sundaram Kumar Jha 941605720f fix: add missing antigravity subscription option 2026-03-28 23:46:01 +05:30
Sundaram Kumar Jha 72e5c5c1c6 test: cover shell config fallbacks 2026-03-28 23:38:02 +05:30
Sundaram Kumar Jha 0f42c8c8c1 fix: align Git Bash shell config handling 2026-03-28 23:37:53 +05:30
RichardTang-Aden c3c3075610 Merge pull request #6811 from Hundao/fix/lazy-import-resend
fix: lazy import resend in email_tool
2026-03-27 14:37:41 -07:00
Bryan @ Aden 86ef6fd8c5 Merge pull request #6822 from sundaram2021/fix/date-formatting-issue-on-windows
micro-fix: fix date formatting issue on windows and mattermost formatting issue
2026-03-27 07:24:26 -07:00
Sundaram Kumar Jha 95bdf4fe32 fix: mattermost formatting issue 2026-03-27 09:55:28 +05:30
Sundaram Kumar Jha 890d303d26 test: cover queen memory date formatting on Windows 2026-03-27 09:46:42 +05:30
Sundaram Kumar Jha 7fe60991e1 fix: use cross-platform queen memory date formatting 2026-03-27 09:46:27 +05:30
RichardTang-Aden a72938a163 Merge pull request #6747 from wakqasahmed/feat/mattermost-integration
feat(tools): add Mattermost messaging platform integration
2026-03-26 15:27:00 -07:00
Richard Tang 326a3dd1b7 docs: add honeycomb in readme 2026-03-26 14:55:00 -07:00
Richard Tang 183c6e2620 docs: readme with harness 2026-03-26 14:50:55 -07:00
Timothy @aden 1b40bff7da Merge pull request #6803 from aden-hive/fix/queen-cannot-read-skills
fix: allow curl in run_command and fix queen custom skill discovery
2026-03-26 12:56:03 -07:00
Timothy @aden 38b79edaee Merge pull request #6633 from sundaram2021/refactor/event-loop-node-modularization
refactor: modularize event loop node class methods and helpers
2026-03-26 12:47:53 -07:00
Sundaram Kumar Jha eb4f180192 chore: pull latest change 2026-03-27 00:48:01 +05:30
Sundaram Kumar Jha bf0b9a1edb refactor: cleanup compact llm function 2026-03-27 00:45:42 +05:30
Sundaram Kumar Jha 9667dd25cb chore: pull latest changes 2026-03-26 21:54:56 +05:30
hundao 33e4e8d440 fix: lazy import resend in email_tool to prevent tool registration crash
Fixes #4816
2026-03-26 18:43:04 +08:00
Shiva Santosh Reddy Aenugu c5ac29c81d fix(frontend): add 404 fallback route for unknown paths (#6373) 2026-03-26 18:24:01 +08:00
vakrahul 13c072d731 fix: match expected error message text in mcp_client and mcp_registry 2026-03-26 15:39:17 +05:30
Aaryann Chandola 5e31975cc3 feat(mcp-cli): add CLI management commands (#6350) (#6787)
* feat(mcp-cli): add hive mcp CLI management commands (#6350)

Implement the hive mcp subcommand group with shared helpers and all
P0/P1 management commands: install, add, remove, enable, disable,
list, info, config, search, health, update.

Includes update bridge (remove+reinstall with rollback on failure),
first-use security notice, credential prompting, secret masking,
and agent usage detection via load_agent_selection().

* test(mcp-cli): add CLI integration and handler tests (#6350)

58 tests covering all commands end-to-end:
- Real framework.cli.main() entrypoint dispatch (list, install, update)
- Real registry-on-disk integration (install, list, config, info, remove)
- All 11 command handlers (install, add, remove, enable, disable, list,
  info, config, search, health, update)
- Security notice shown only once
- Credential prompting stores overrides, skips when env set, handles cancel
- Secret masking in human output, JSON output, and config display
- Index refresh semantics (stale cache fallback vs no-cache hard fail)
- Update rollback on reinstall failure preserves original entry
- Update rejects local servers and pinned servers with correct remediation
- Bulk update skips local and pinned servers
- Argparse registration validates all 11 subcommands present
- _find_agents_using_server resolves via real load_agent_selection
- _parse_key_value_pairs validates KEY=VAL format

* fix(mcp-cli): mask list --json secrets, preserve enabled state on update, defer security sentinel (#6350)

- list --json now masks override values as <set> before emitting
- update preserves enabled=False state across reinstall
- security notice sentinel only written after successful install

* refactor(mcp-cli): fix docstring, share registry instance in update, extract _mask_overrides helper (#6350)

- Fix module docstring to reflect update's full behavior
- Pass registry instance to _cmd_mcp_update_server to avoid redundant disk I/O
- Extract _mask_overrides() used by list --json, info --json, info human, and config display
- Add comment about _find_agents_using_server path arithmetic limitation
2026-03-26 18:01:28 +08:00
vakrahul 82af76e72a feat: wire structured MCP errors into mcp_registry.py (closes #6352) 2026-03-26 15:30:10 +05:30
Amogh Raj a483f8d06a docs: add Windows quickstart.ps1 command in Quick Start section (#6781)
* docs: add Windows quickstart.ps1 command in Quick Start section

* fix: restore closing code fence and comment out Windows command

---------

Co-authored-by: hundao <alchemy_wimp@hotmail.com>
2026-03-26 17:27:31 +08:00
Sundaram Kumar Jha e188c26e9f chore: revert changes 2026-03-26 08:11:39 +05:30
Fernando Mano 22d9fba1fd Feature: #6351 - Agent selection, tool resolution & framework integration -- MCP Registry integration deleted local test code -- fix failing tests 2026-03-24 22:56:56 -03:00
Fernando Mano c7d0afc775 Feature: #6351 - Agent selection, tool resolution & framework integration
Made-with: Cursor
2026-03-24 22:34:52 -03:00
Fernando Mano 45aafbc52b Merge branch 'main' into feat/agent-selection-tool-resolution-n-framework-integration 2026-03-24 17:02:08 -03:00
Levin 567340c05d Merge branch 'aden-hive:main' into skills/cli-commands 2026-03-24 22:58:03 +05:30
levxn 8d8656193d bug fix 2026-03-24 22:09:54 +05:30
levxn ef317371ce hive skill test implemented, --json flag for machine parsable outputs, fixed lints 2026-03-24 21:51:14 +05:30
Levin d5596ccb0a Merge branch 'aden-hive:main' into skills/cli-commands 2026-03-24 21:47:35 +05:30
levxn 5f1530ec5b minor bug fix, and lint issue fixes 2026-03-24 15:52:10 +05:30
Levin 8af32b421c Merge branch 'aden-hive:main' into skills/cli-commands 2026-03-24 13:53:40 +05:30
levxn 95cc8a4513 cli commands, v1 2026-03-24 02:23:20 +05:30
Sundaram Kumar Jha d648f3d315 refactor(event-loop): slim event loop node orchestration 2026-03-24 01:00:08 +05:30
Sundaram Kumar Jha b43044cf4d refactor(event-loop): untangle modular event loop imports 2026-03-24 00:59:55 +05:30
Sundaram Kumar Jha 4724320946 refactor(event-loop): add shared event loop types 2026-03-24 00:59:35 +05:30
Waqas Ahmed 89ab2e0a74 feat(tools): add Mattermost messaging platform integration
Add Mattermost as a new messaging tool following the existing Discord/Telegram
pattern. Supports self-hosted and cloud instances via personal access tokens.

Tools: list_teams, list_channels, get_channel, send_message, get_posts,
create_reaction, delete_post. Includes rate limit retry logic, credential
store + env var fallback, and comprehensive tests (41 unit + 50 conformance).

Closes #6746

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 14:07:30 +02:00
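The "rate limit retry logic" mentioned above can be sketched as exponential backoff on HTTP 429. This is a generic sketch under assumed names, not the Mattermost tool's actual code.

```python
import time


def request_with_retry(send, max_retries: int = 3, base_delay: float = 0.5):
    """Retry a request callable on HTTP 429 with exponential backoff.

    `send` is any zero-argument callable returning an object with a
    `status_code` attribute (e.g. a bound requests call).
    """
    resp = send()
    for attempt in range(max_retries):
        if resp.status_code != 429:
            return resp
        time.sleep(base_delay * (2**attempt))  # 0.5s, 1s, 2s, ...
        resp = send()
    return resp  # last response, possibly still 429
```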
Sundaram Kumar Jha ee4682c565 chore: pull latest changes ; fix: merge conflict 2026-03-22 08:52:17 +05:30
Sundaram Kumar Jha 9b59255770 chore: pull latest change , refactor: modularize latest change 2026-03-20 11:05:12 +05:30
Sundaram Kumar Jha 49fd443da8 chore: resolve merge conflict 2026-03-20 10:01:07 +05:30
Fernando Mano b599a760e8 Feature: #6351 - Agent selection, tool resolution & framework integration -- first version with mocked MCPRegistry 2026-03-19 10:48:37 -03:00
Levin b4a37cdb03 Merge branch 'aden-hive:main' into skills/ds-ovrride-heuristics 2026-03-19 18:51:29 +05:30
Sundaram Kumar Jha 4885db318e fix: merge conflict 2026-03-19 09:22:44 +05:30
Sundaram Kumar Jha fa7ce53fb3 style(repo): fix ruff format violations
Apply Ruff formatting to the extracted event loop modules, the EventLoopNode wrappers, and the OpenRouter key check script so the lint CI format check passes cleanly.
2026-03-19 09:20:18 +05:30
Sundaram Kumar Jha 75a2ef2c4a Merge branch 'main' into refactor/event-loop-node-modularization 2026-03-19 09:14:10 +05:30
Sundaram Kumar Jha a0b9d6afaf chore: refresh locks 2026-03-19 09:08:10 +05:30
Sundaram Kumar Jha 74c0a85e3f refactor(graph): modularize event loop helpers
Extract EventLoopNode helper logic into focused event_loop modules while keeping the node responsible for orchestration.

Preserve the existing behavior and compatibility for compaction, event publishing, cursor persistence, synthetic tools, judge evaluation, stall detection, tool result handling, and subagent escalation wiring.
2026-03-19 09:07:19 +05:30
levxn 96609386a3 lints fixed 2026-03-18 22:18:16 +05:30
levxn 0cef0e6990 DS-12, DS-13 skill config overrides and runtime heuristics 2026-03-18 22:13:09 +05:30
vakrahul 2f15a16159 feat: structured MCP error codes and failure diagnostics (closes #6352) 2026-03-16 22:50:14 +05:30
Kartik d433cda209 fix: use CredentialStoreAdapter in sap_tool instead of raw os.getenv()
Made-with: Cursor
2026-03-13 22:30:50 +05:30
90 changed files with 13013 additions and 3590 deletions
+4 -2
View File
@@ -13,6 +13,10 @@ out/
.env
.env.local
.env.*.local
.venv
/venv
tools/src/uv.lock
# User configuration (copied from .example)
config.yaml
@@ -69,8 +73,6 @@ exports/*
.claude/settings.local.json
.venv
docs/github-issues/*
core/tests/*dumps/*
+23 -16
View File
@@ -23,6 +23,7 @@
</p>
<p align="center">
<img src="https://img.shields.io/badge/Agent_Harness-Runtime_Layer-ff6600?style=flat-square" alt="Agent Harness" />
<img src="https://img.shields.io/badge/AI_Agents-Self--Improving-brightgreen?style=flat-square" alt="AI Agents" />
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
<img src="https://img.shields.io/badge/Headless-Development-purple?style=flat-square" alt="Headless" />
@@ -35,39 +36,42 @@
<img src="https://img.shields.io/badge/Google_Gemini-supported-4285F4?style=flat-square&logo=google" alt="Gemini" />
</p>
<p align="center"><em>The agent harness for production workloads — state management, failure recovery, observability, and human oversight so your agents actually run.</em></p>
## Overview
Generate a swarm of worker agents with a coding agent(queen) that control them. Define your goal through conversation with hive queen, and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, browser use, credential management, and real-time monitoring give you control without sacrificing adaptability.
Hive is a runtime harness for AI agents in production. You describe your goal in natural language; a coding agent (the queen) generates the agent graph and connection code to achieve it. During execution, the harness manages state isolation, checkpoint-based crash recovery, cost enforcement, and real-time observability. When agents fail, the framework captures failure data, evolves the graph through the coding agent, and redeploys automatically. Built-in human-in-the-loop nodes, browser control, credential management, and parallel execution give you production reliability without sacrificing adaptability.
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
Visit [HoneyComb](http://honeycomb.open-hive.com/) to see what jobs are being automated by AI. It's a stock market for jobs, driven by our community's AI agent progress. You can long and short jobs (with no real money, just compute tokens) based on how much you think a job is going to be replaced by AI.
https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
## Who Is Hive For?
Hive is designed for developers and teams who want to build many **autonomous AI agents** fast without manually wiring complex workflows.
Hive is the harness layer for teams moving AI agents from prototype to production. Models are getting better on their own — the bottleneck is the infrastructure around them: state management, failure recovery, cost control, and observability.
Hive is a good fit if you:
- Want AI agents that **execute real business processes**, not demos
- Need **fast or high volume agent execution** over open workflow
- Need a **runtime that handles state, recovery, and parallel execution** at scale
- Need **self-healing and adaptive agents** that improve over time
- Require **human-in-the-loop control**, observability, and cost limits
- Plan to run agents in **production environments**
- Plan to run agents in **production** where uptime, cost, and auditability matter
Hive may not be the best fit if you're only experimenting with simple agent chains or one-off scripts.
## When Should You Use Hive?
Use Hive when you need:
Use Hive when the bottleneck is no longer the model but the harness around it:
- Long-running, autonomous agents
- Strong guardrails, process, and controls
- Continuous improvement based on failures
- Multi-agent coordination
- A framework that evolves with your goals
- Long-running agents that need **state persistence and crash recovery**
- Production workloads requiring **cost enforcement, observability, and audit trails**
- Agents that **self-heal** through failure capture and graph evolution
- Multi-agent coordination with **session isolation and shared memory**
- A framework that **scales with model improvements** rather than fighting them
## Quick Links
@@ -100,9 +104,11 @@ Use Hive when you need:
git clone https://github.com/aden-hive/hive.git
cd hive
# Run quickstart setup
# Run quickstart setup (macOS/Linux)
./quickstart.sh
# Windows (PowerShell)
.\quickstart.ps1
```
This sets up:
@@ -152,9 +158,9 @@ Hive is built to be model-agnostic and system-agnostic.
- **LLM flexibility** - Hive Framework supports Anthropic, OpenAI, OpenRouter, Hive LLM, and other hosted or local models through LiteLLM-compatible providers.
- **Business system connectivity** - Hive Framework is designed to connect to all kinds of business systems as tools, such as CRM, support, messaging, data, file, and internal APIs via MCP.
## Why Aden
## Why Hive
Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
As models improve, the upper bound of what agents can do rises — but their reliability and production value are determined by the harness. Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
```mermaid
flowchart LR
@@ -190,8 +196,9 @@ flowchart LR
### The Hive Advantage
| Traditional Frameworks | Hive |
| Typical Agent Frameworks | Hive |
| -------------------------- | -------------------------------------- |
| Focus on model orchestration | **Production harness**: state, recovery, observability |
| Hardcode agent workflows | Describe goals in natural language |
| Manual graph definition | Auto-generated agent graphs |
| Reactive error handling | Outcome-evaluation and adaptiveness |
@@ -385,7 +392,7 @@ Yes! Hive supports local models through LiteLLM. Simply use the model name forma
**Q: What makes Hive different from other agent frameworks?**
Hive generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, [evolves the agent graph](docs/key_concepts/evolution.md), and redeploys. This self-improving loop is unique to Aden.
Hive is an agent harness, not just an orchestration framework. It provides the production runtime layer — session isolation, checkpoint-based crash recovery, cost enforcement, real-time observability, and human-in-the-loop controls — that makes agents reliable enough to run real workloads. On top of that, Hive generates your entire agent system from natural language goals and automatically [evolves the graph](docs/key_concepts/evolution.md) when agents fail. The combination of a robust harness with self-improving generation is what sets Hive apart.
**Q: Is Hive open-source?**
+88 -3
View File
@@ -6,7 +6,7 @@ This guide explains how to integrate Model Context Protocol (MCP) servers with t
The framework provides built-in support for MCP servers, allowing you to:
- **Register MCP servers** via STDIO or HTTP transport
- **Register MCP servers** via STDIO, HTTP, Unix socket, or SSE transport
- **Auto-discover tools** from registered servers
- **Use MCP tools** seamlessly in your agents
- **Manage multiple MCP servers** simultaneously
@@ -104,6 +104,48 @@ runner.register_mcp_server(
- `url`: Base URL of the MCP server
- `headers`: HTTP headers to include (optional)
### Unix Socket Transport
Best for same-host inter-process communication with lower overhead than TCP:
```python
runner.register_mcp_server(
name="local-ipc-tools",
transport="unix",
url="http://localhost",
socket_path="/tmp/mcp_server.sock",
headers={
"Authorization": "Bearer token"
}
)
```
**Configuration:**
- `url`: Base URL for HTTP requests over the socket (required, e.g., `"http://localhost"`)
- `socket_path`: Absolute path to the Unix socket file (required, e.g., `"/tmp/mcp_server.sock"`)
- `headers`: HTTP headers to include (optional)
### SSE Transport
Best for real-time, event-driven connections using the MCP SDK's SSE client:
```python
runner.register_mcp_server(
name="streaming-tools",
transport="sse",
url="http://localhost:8000/sse",
headers={
"Authorization": "Bearer token"
}
)
```
**Configuration:**
- `url`: SSE endpoint URL (required, e.g., `"http://localhost:8000/sse"`)
- `headers`: HTTP headers for the SSE connection (optional)
## Using MCP Tools in Agents
Once registered, MCP tools are available just like any other tool:
@@ -258,7 +300,32 @@ runner.register_mcp_server(
)
```
### 3. Handle Cleanup
### 3. Use Unix Socket for Same-Host IPC
When both the agent and MCP server run on the same machine, Unix sockets avoid TCP overhead:
```python
runner.register_mcp_server(
name="fast-local-tools",
transport="unix",
url="http://localhost",
socket_path="/tmp/mcp_server.sock"
)
```
### 4. Use SSE for Streaming and Real-Time Tools
SSE transport maintains a persistent connection, ideal for event-driven servers:
```python
runner.register_mcp_server(
name="realtime-tools",
transport="sse",
url="http://realtime-server:8000/sse"
)
```
### 5. Handle Cleanup
Always clean up MCP connections when done:
@@ -280,7 +347,7 @@ async with AgentRunner.load("exports/my-agent") as runner:
# Automatic cleanup
```
### 4. Tool Name Conflicts
### 6. Tool Name Conflicts
If multiple MCP servers provide tools with the same name, the last registered server wins. To avoid conflicts:
@@ -315,6 +382,24 @@ If HTTP transport fails:
2. Check firewall settings
3. Verify the URL and port are correct
### Unix Socket Not Connecting
If Unix socket transport fails:
1. Verify the socket file exists: `ls -la /tmp/mcp_server.sock`
2. Check file permissions on the socket
3. Ensure no other process has locked the socket
4. Verify the `url` field is set (e.g., `"http://localhost"`)
### SSE Connection Issues
If SSE transport fails:
1. Verify the server supports SSE at the given URL
2. Check that the `mcp` Python package is installed (`pip install mcp`)
3. Ensure the SSE endpoint is accessible: `curl http://localhost:8000/sse`
4. Check for firewall or proxy issues blocking long-lived connections
## Example: Full Agent with MCP Tools
Here's a complete example of an agent that uses MCP tools:
@@ -584,11 +584,19 @@ class CredentialTesterAgent:
self._tool_registry.load_mcp_config(mcp_config_path)
try:
agent_dir = Path(__file__).parent
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(Path(__file__).parent)
if (agent_dir / "mcp_registry.json").is_file():
self._tool_registry.set_mcp_registry_agent_path(agent_dir)
registry_configs, selection_max_tools = registry.load_agent_selection(agent_dir)
if registry_configs:
self._tool_registry.load_registry_servers(registry_configs)
self._tool_registry.load_registry_servers(
registry_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
except Exception:
logger.warning("MCP registry config failed to load", exc_info=True)
+9 -4
View File
@@ -31,6 +31,11 @@ def _queen_dir() -> Path:
return Path.home() / ".hive" / "queen"
def format_memory_date(d: date) -> str:
"""Return a cross-platform long date label without a zero-padded day."""
return f"{d.strftime('%B')} {d.day}, {d.year}"
def semantic_memory_path() -> Path:
return _queen_dir() / "MEMORY.md"
@@ -91,9 +96,9 @@ def format_for_injection() -> str:
content = content[:_EPISODIC_CHAR_BUDGET] + "\n\n…(truncated)"
today = date.today()
if d == today:
label = f"## Today — {d.strftime('%B %-d, %Y')}"
label = f"## Today — {format_memory_date(d)}"
else:
label = f"## {d.strftime('%B %-d, %Y')}"
label = f"## {format_memory_date(d)}"
parts.append(f"{label}\n\n{content}")
if not parts:
@@ -127,7 +132,7 @@ def append_episodic_entry(content: str) -> None:
ep_path = episodic_memory_path()
ep_path.parent.mkdir(parents=True, exist_ok=True)
today = date.today()
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
today_str = format_memory_date(today)
timestamp = datetime.now().strftime("%H:%M")
if not ep_path.exists():
header = f"# {today_str}\n\n"
@@ -331,7 +336,7 @@ async def consolidate_queen_memory(
existing_semantic = read_semantic_memory()
today_journal = read_episodic_memory()
today = date.today()
- today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
+ today_str = format_memory_date(today)
adapt_path = session_dir / "data" / "adapt.md"
user_msg = (
@@ -99,6 +99,11 @@ def main():
register_debugger_commands(subparsers)
# Register MCP registry commands (mcp install, mcp add, ...)
from framework.runner.mcp_registry_cli import register_mcp_commands
register_mcp_commands(subparsers)
args = parser.parse_args()
if hasattr(args, "func"):
@@ -186,6 +186,8 @@ def get_worker_llm_extra_kwargs() -> dict[str, Any]:
"store": False,
"allowed_openai_params": ["store"],
}
if worker_llm.get("provider") == "ollama":
return {"num_ctx": worker_llm.get("num_ctx", 16384)}
return {}
@@ -432,6 +434,11 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
"store": False,
"allowed_openai_params": ["store"],
}
if llm.get("provider") == "ollama":
# Pass num_ctx to Ollama so it doesn't silently truncate the ~9.5k Queen prompt.
# Ollama's default num_ctx is only 2048. We set it to 16384 here so LiteLLM
# passes it through as a provider-specific option.
return {"num_ctx": llm.get("num_ctx", 16384)}
return {}
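The Ollama branch above guards against a quiet failure mode: Ollama's default context window (`num_ctx`) is 2048 tokens, so a ~9.5k-token prompt is truncated without any error unless a larger window is passed through as a provider option. A minimal sketch of the dispatch pattern, with a hypothetical helper name (the real functions read a richer config):

```python
from typing import Any

def extra_kwargs(llm: dict[str, Any]) -> dict[str, Any]:
    """Return provider-specific kwargs for the LLM call (illustrative sketch)."""
    if llm.get("provider") == "ollama":
        # Ollama defaults num_ctx to 2048, which silently truncates long
        # prompts; default to a 16k window unless the config overrides it.
        return {"num_ctx": llm.get("num_ctx", 16384)}
    return {}
```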
@@ -0,0 +1,6 @@
"""EventLoopNode subpackage — modular components of the event loop orchestrator.
All public symbols are re-exported by the parent ``event_loop_node.py`` for
backward compatibility. Internal consumers may import directly from these
submodules for clarity.
"""
@@ -0,0 +1,652 @@
"""Conversation compaction pipeline.
Implements the multi-level compaction strategy:
1. Prune old tool results
2. Structure-preserving compaction (spillover)
3. LLM summary compaction (with recursive splitting)
4. Emergency deterministic summary (no LLM)
"""
from __future__ import annotations
import json
import logging
import os
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from framework.graph.conversation import NodeConversation
from framework.graph.event_loop.event_publishing import publish_context_usage
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
from framework.graph.node import NodeContext
from framework.runtime.event_bus import EventBus
logger = logging.getLogger(__name__)
# Limits for LLM compaction
LLM_COMPACT_CHAR_LIMIT: int = 240_000
LLM_COMPACT_MAX_DEPTH: int = 10
async def compact(
ctx: NodeContext,
conversation: NodeConversation,
accumulator: OutputAccumulator | None,
*,
config: LoopConfig,
event_bus: EventBus | None,
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
max_depth: int = LLM_COMPACT_MAX_DEPTH,
) -> None:
"""Run the full compaction pipeline if conversation needs compaction.
Pipeline stages (in order, short-circuits when budget is restored):
1. Prune old tool results
2. Structure-preserving compaction (free, no LLM)
3. LLM summary compaction (recursive split if too large)
4. Emergency deterministic summary (fallback)
"""
ratio_before = conversation.usage_ratio()
phase_grad = getattr(ctx, "continuous_mode", False)
pre_inventory: list[dict[str, Any]] | None = None
if ratio_before >= 1.0:
pre_inventory = build_message_inventory(conversation)
# --- Step 1: Prune old tool results (free, fast) ---
protect = max(2000, config.max_context_tokens // 12)
pruned = await conversation.prune_old_tool_results(
protect_tokens=protect,
min_prune_tokens=max(1000, protect // 3),
)
if pruned > 0:
logger.info(
"Pruned %d old tool results: %.0f%% -> %.0f%%",
pruned,
ratio_before * 100,
conversation.usage_ratio() * 100,
)
if not conversation.needs_compaction():
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
pre_inventory=pre_inventory,
)
return
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
spill_dir = config.spillover_dir
if spill_dir:
await conversation.compact_preserving_structure(
spillover_dir=spill_dir,
keep_recent=4,
phase_graduated=phase_grad,
)
if not conversation.needs_compaction():
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
pre_inventory=pre_inventory,
)
return
# --- Step 3: LLM summary compaction ---
if ctx.llm is not None:
logger.info(
"LLM summary compaction triggered (%.0f%% usage)",
conversation.usage_ratio() * 100,
)
try:
summary = await llm_compact(
ctx,
list(conversation.messages),
accumulator,
char_limit=char_limit,
max_depth=max_depth,
max_context_tokens=config.max_context_tokens,
)
await conversation.compact(
summary,
keep_recent=2,
phase_graduated=phase_grad,
)
except Exception as e:
logger.warning("LLM compaction failed: %s", e)
if not conversation.needs_compaction():
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
pre_inventory=pre_inventory,
)
return
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
logger.warning(
"Emergency compaction (%.0f%% usage)",
conversation.usage_ratio() * 100,
)
summary = build_emergency_summary(ctx, accumulator, conversation, config)
await conversation.compact(
summary,
keep_recent=1,
phase_graduated=phase_grad,
)
await log_compaction(
ctx,
conversation,
ratio_before,
event_bus,
pre_inventory=pre_inventory,
)
# --- LLM compaction with binary-search splitting ----------------------
async def llm_compact(
ctx: NodeContext,
messages: list,
accumulator: OutputAccumulator | None = None,
_depth: int = 0,
*,
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
max_depth: int = LLM_COMPACT_MAX_DEPTH,
max_context_tokens: int = 128_000,
) -> str:
"""Summarise *messages* with LLM, splitting recursively if too large.
If the formatted text exceeds ``char_limit`` or the LLM
rejects the call with a context-length error, the messages are split
in half and each half is summarised independently. Tool history is
appended once at the top-level call (``_depth == 0``).
"""
from framework.graph.conversation import extract_tool_call_history
from framework.graph.event_loop.tool_result_handler import is_context_too_large_error
if _depth > max_depth:
raise RuntimeError(f"LLM compaction recursion limit ({max_depth})")
formatted = format_messages_for_summary(messages)
# Proactive split: avoid wasting an API call on oversized input
if len(formatted) > char_limit and len(messages) > 1:
summary = await _llm_compact_split(
ctx,
messages,
accumulator,
_depth,
char_limit=char_limit,
max_depth=max_depth,
max_context_tokens=max_context_tokens,
)
else:
prompt = build_llm_compaction_prompt(
ctx,
accumulator,
formatted,
max_context_tokens=max_context_tokens,
)
summary_budget = max(1024, max_context_tokens // 2)
try:
response = await ctx.llm.acomplete(
messages=[{"role": "user", "content": prompt}],
system=(
"You are a conversation compactor for an AI agent. "
"Write a detailed summary that allows the agent to "
"continue its work. Preserve user-stated rules, "
"constraints, and account/identity preferences verbatim."
),
max_tokens=summary_budget,
)
summary = response.content
except Exception as e:
if is_context_too_large_error(e) and len(messages) > 1:
logger.info(
"LLM context too large (depth=%d, msgs=%d) — splitting",
_depth,
len(messages),
)
summary = await _llm_compact_split(
ctx,
messages,
accumulator,
_depth,
char_limit=char_limit,
max_depth=max_depth,
max_context_tokens=max_context_tokens,
)
else:
raise
# Append tool history at top level only
if _depth == 0:
tool_history = extract_tool_call_history(messages)
if tool_history and "TOOLS ALREADY CALLED" not in summary:
summary += "\n\n" + tool_history
return summary
async def _llm_compact_split(
ctx: NodeContext,
messages: list,
accumulator: OutputAccumulator | None,
_depth: int,
*,
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
max_depth: int = LLM_COMPACT_MAX_DEPTH,
max_context_tokens: int = 128_000,
) -> str:
"""Split messages in half and summarise each half independently."""
mid = max(1, len(messages) // 2)
s1 = await llm_compact(
ctx,
messages[:mid],
None,
_depth + 1,
char_limit=char_limit,
max_depth=max_depth,
max_context_tokens=max_context_tokens,
)
s2 = await llm_compact(
ctx,
messages[mid:],
accumulator,
_depth + 1,
char_limit=char_limit,
max_depth=max_depth,
max_context_tokens=max_context_tokens,
)
return s1 + "\n\n" + s2
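The split strategy above can be illustrated without an LLM: when the formatted input exceeds the budget, halve the message list, summarise each half independently (recursing as needed), and join the results. In this sketch, truncation stands in for the real LLM summary call:

```python
def compact(messages: list[str], char_limit: int, depth: int = 0, max_depth: int = 10) -> str:
    """Recursively split-and-summarise, mirroring llm_compact's control flow."""
    if depth > max_depth:
        raise RuntimeError(f"compaction recursion limit ({max_depth})")
    formatted = "\n".join(messages)
    if len(formatted) > char_limit and len(messages) > 1:
        # Proactive split: summarise each half, then concatenate.
        mid = max(1, len(messages) // 2)
        s1 = compact(messages[:mid], char_limit, depth + 1, max_depth)
        s2 = compact(messages[mid:], char_limit, depth + 1, max_depth)
        return s1 + "\n\n" + s2
    return formatted[:char_limit]  # stand-in for the LLM summary
```

The recursion depth limit bounds worst-case API usage; with binary splitting, depth 10 already covers over a thousand leaf summaries.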
# --- Compaction helpers ------------------------------------------------
def format_messages_for_summary(messages: list) -> str:
"""Format messages as text for LLM summarisation."""
lines: list[str] = []
for m in messages:
if m.role == "tool":
content = m.content[:500]
if len(m.content) > 500:
content += "..."
lines.append(f"[tool result]: {content}")
elif m.role == "assistant" and m.tool_calls:
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
text = m.content[:200] if m.content else ""
lines.append(f"[assistant (calls: {', '.join(names)})]: {text}")
else:
lines.append(f"[{m.role}]: {m.content}")
return "\n\n".join(lines)
def build_llm_compaction_prompt(
ctx: NodeContext,
accumulator: OutputAccumulator | None,
formatted_messages: str,
*,
max_context_tokens: int = 128_000,
) -> str:
"""Build prompt for LLM compaction targeting 50% of token budget."""
spec = ctx.node_spec
ctx_lines = [f"NODE: {spec.name} (id={spec.id})"]
if spec.description:
ctx_lines.append(f"PURPOSE: {spec.description}")
if spec.success_criteria:
ctx_lines.append(f"SUCCESS CRITERIA: {spec.success_criteria}")
if accumulator:
acc = accumulator.to_dict()
done = {k: v for k, v in acc.items() if v is not None}
todo = [k for k, v in acc.items() if v is None]
if done:
ctx_lines.append(
"OUTPUTS ALREADY SET:\n"
+ "\n".join(f" {k}: {str(v)[:150]}" for k, v in done.items())
)
if todo:
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(todo)}")
elif spec.output_keys:
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
target_tokens = max_context_tokens // 2
target_chars = target_tokens * 4
node_ctx = "\n".join(ctx_lines)
return (
"You are compacting an AI agent's conversation history. "
"The agent is still working and needs to continue.\n\n"
f"AGENT CONTEXT:\n{node_ctx}\n\n"
f"CONVERSATION MESSAGES:\n{formatted_messages}\n\n"
"INSTRUCTIONS:\n"
f"Write a summary of approximately {target_chars} characters "
f"(~{target_tokens} tokens).\n"
"1. Preserve ALL user-stated rules, constraints, and preferences "
"verbatim.\n"
"2. Preserve key decisions made and results obtained.\n"
"3. Preserve in-progress work state so the agent can continue.\n"
"4. Be detailed enough that the agent can resume without "
"re-doing work.\n"
)
def build_message_inventory(conversation: NodeConversation) -> list[dict[str, Any]]:
"""Build a per-message size inventory for debug logging."""
inventory: list[dict[str, Any]] = []
for message in conversation.messages:
content_chars = len(message.content)
tool_call_args_chars = 0
tool_name = None
if message.tool_calls:
for tool_call in message.tool_calls:
args = tool_call.get("function", {}).get("arguments", "")
tool_call_args_chars += (
len(args) if isinstance(args, str) else len(json.dumps(args))
)
names = [
tool_call.get("function", {}).get("name", "?") for tool_call in message.tool_calls
]
tool_name = ", ".join(names)
elif message.role == "tool" and message.tool_use_id:
for previous in conversation.messages:
if previous.tool_calls:
for tool_call in previous.tool_calls:
if tool_call.get("id") == message.tool_use_id:
tool_name = tool_call.get("function", {}).get("name", "?")
break
if tool_name:
break
entry: dict[str, Any] = {
"seq": message.seq,
"role": message.role,
"content_chars": content_chars,
}
if tool_call_args_chars:
entry["tool_call_args_chars"] = tool_call_args_chars
if tool_name:
entry["tool"] = tool_name
if message.is_error:
entry["is_error"] = True
if message.phase_id:
entry["phase"] = message.phase_id
if content_chars > 2000:
entry["preview"] = message.content[:200] + "…"
inventory.append(entry)
return inventory
def write_compaction_debug_log(
ctx: NodeContext,
before_pct: int,
after_pct: int,
level: str,
inventory: list[dict[str, Any]] | None,
) -> None:
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
log_dir = Path.home() / ".hive" / "compaction_log"
log_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
node_label = ctx.node_id.replace("/", "_")
log_path = log_dir / f"{ts}_{node_label}.md"
lines: list[str] = [
f"# Compaction Debug — {ctx.node_id}",
f"**Time:** {datetime.now(UTC).isoformat()}",
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
]
if ctx.stream_id:
lines.append(f"**Stream:** {ctx.stream_id}")
lines.append(f"**Level:** {level}")
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
lines.append("")
if inventory:
total_chars = sum(
entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
for entry in inventory
)
lines.append(
"## Pre-Compaction Message Inventory "
f"({len(inventory)} messages, {total_chars:,} total chars)"
)
lines.append("")
ranked = sorted(
inventory,
key=lambda entry: entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0),
reverse=True,
)
lines.append("| # | seq | role | tool | chars | % of total | flags |")
lines.append("|---|-----|------|------|------:|------------|-------|")
for i, entry in enumerate(ranked, 1):
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
pct = (chars / total_chars * 100) if total_chars else 0
tool = entry.get("tool", "")
flags: list[str] = []
if entry.get("is_error"):
flags.append("error")
if entry.get("phase"):
flags.append(f"phase={entry['phase']}")
lines.append(
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
)
large = [entry for entry in ranked if entry.get("preview")]
if large:
lines.append("")
lines.append("### Large message previews")
for entry in large:
lines.append(
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
)
lines.append(f"```\n{entry['preview']}\n```")
lines.append("")
try:
log_path.write_text("\n".join(lines), encoding="utf-8")
logger.debug("Compaction debug log written to %s", log_path)
except OSError:
logger.debug("Failed to write compaction debug log to %s", log_path)
async def log_compaction(
ctx: NodeContext,
conversation: NodeConversation,
ratio_before: float,
event_bus: EventBus | None,
*,
pre_inventory: list[dict[str, Any]] | None = None,
) -> None:
"""Log compaction result to runtime logger and event bus."""
ratio_after = conversation.usage_ratio()
before_pct = round(ratio_before * 100)
after_pct = round(ratio_after * 100)
# Determine label from what happened
if after_pct >= before_pct - 1:
level = "prune_only"
elif ratio_after <= 0.6:
level = "llm"
else:
level = "structural"
logger.info(
"Compaction complete (%s): %d%% -> %d%%",
level,
before_pct,
after_pct,
)
if ctx.runtime_logger:
ctx.runtime_logger.log_step(
node_id=ctx.node_id,
node_type="event_loop",
step_index=-1,
llm_text=f"Context compacted ({level}): {before_pct}% \u2192 {after_pct}%",
verdict="COMPACTION",
verdict_feedback=f"level={level} before={before_pct}% after={after_pct}%",
)
if event_bus:
from framework.runtime.event_bus import AgentEvent, EventType
event_data: dict[str, Any] = {
"level": level,
"usage_before": before_pct,
"usage_after": after_pct,
}
if pre_inventory is not None:
event_data["message_inventory"] = pre_inventory
await event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_COMPACTED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data=event_data,
)
)
await publish_context_usage(event_bus, ctx, conversation, "post_compaction")
if os.environ.get("HIVE_COMPACTION_DEBUG"):
write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
def build_emergency_summary(
ctx: NodeContext,
accumulator: OutputAccumulator | None = None,
conversation: NodeConversation | None = None,
config: LoopConfig | None = None,
) -> str:
"""Build a structured emergency compaction summary.
Unlike normal/aggressive compaction which uses an LLM summary,
emergency compaction cannot afford an LLM call (context is already
way over budget). Instead, build a deterministic summary from the
node's known state so the LLM can continue working after
compaction without losing track of its task and inputs.
"""
parts = [
"EMERGENCY COMPACTION — previous conversation was too large "
"and has been replaced with this summary.\n"
]
# 1. Node identity
spec = ctx.node_spec
parts.append(f"NODE: {spec.name} (id={spec.id})")
if spec.description:
parts.append(f"PURPOSE: {spec.description}")
# 2. Inputs the node received
input_lines = []
for key in spec.input_keys:
value = ctx.input_data.get(key) or ctx.memory.read(key)
if value is not None:
# Truncate long values but keep them recognisable
v_str = str(value)
if len(v_str) > 200:
v_str = v_str[:200] + "…"
input_lines.append(f" {key}: {v_str}")
if input_lines:
parts.append("INPUTS:\n" + "\n".join(input_lines))
# 3. Output accumulator state (what's been set so far)
if accumulator:
acc_state = accumulator.to_dict()
set_keys = {k: v for k, v in acc_state.items() if v is not None}
missing = [k for k, v in acc_state.items() if v is None]
if set_keys:
lines = [f" {k}: {str(v)[:150]}" for k, v in set_keys.items()]
parts.append("OUTPUTS ALREADY SET:\n" + "\n".join(lines))
if missing:
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(missing)}")
elif spec.output_keys:
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
# 4. Available tools reminder
if spec.tools:
parts.append(f"AVAILABLE TOOLS: {', '.join(spec.tools)}")
# 5. Spillover files — list actual files so the LLM can load
# them immediately instead of having to call list_data_files first.
# Inline adapt.md (agent memory) directly — it contains user rules
# and identity preferences that must survive emergency compaction.
spillover_dir = config.spillover_dir if config else None
if spillover_dir:
try:
from pathlib import Path
data_dir = Path(spillover_dir)
if data_dir.is_dir():
# Inline adapt.md content directly
adapt_path = data_dir / "adapt.md"
if adapt_path.is_file():
adapt_text = adapt_path.read_text(encoding="utf-8").strip()
if adapt_text:
parts.append(f"AGENT MEMORY (adapt.md):\n{adapt_text}")
all_files = sorted(
f.name for f in data_dir.iterdir() if f.is_file() and f.name != "adapt.md"
)
# Separate conversation history files from regular data files
conv_files = [f for f in all_files if re.match(r"conversation_\d+\.md$", f)]
data_files = [f for f in all_files if f not in conv_files]
if conv_files:
conv_list = "\n".join(
f" - {f} (full path: {data_dir / f})" for f in conv_files
)
parts.append(
"CONVERSATION HISTORY (freeform messages saved during compaction — "
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
)
if data_files:
file_list = "\n".join(
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
)
parts.append("DATA FILES (use load_data('<filename>') to read):\n" + file_list)
if not all_files:
parts.append(
"NOTE: Large tool results may have been saved to files. "
"Use list_directory to check the data directory."
)
except Exception:
parts.append(
"NOTE: Large tool results were saved to files. "
"Use read_file(path='<path>') to read them."
)
# 6. Tool call history (prevent re-calling tools)
if conversation is not None:
tool_history = _extract_tool_call_history(conversation)
if tool_history:
parts.append(tool_history)
parts.append(
"\nContinue working towards setting the remaining outputs. "
"Use your tools and the inputs above."
)
return "\n\n".join(parts)
def _extract_tool_call_history(conversation: NodeConversation) -> str:
"""Extract tool call history from conversation messages.
This is the instance-level variant that operates on a NodeConversation
directly (vs. the module-level extract_tool_call_history in conversation.py
which works on raw message lists).
"""
from framework.graph.conversation import extract_tool_call_history
return extract_tool_call_history(list(conversation.messages))
@@ -0,0 +1,239 @@
"""Cursor persistence, queue draining, and pause detection.
Handles the checkpoint/resume cycle: restoring state from a previous
conversation store, writing cursor data, and managing injection/trigger
queues between iterations.
"""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from framework.graph.conversation import ConversationStore, NodeConversation
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator, TriggerEvent
from framework.graph.node import NodeContext
from framework.llm.capabilities import supports_image_tool_results
logger = logging.getLogger(__name__)
@dataclass
class RestoredState:
"""State recovered from a previous checkpoint."""
conversation: NodeConversation
accumulator: OutputAccumulator
start_iteration: int
recent_responses: list[str]
recent_tool_fingerprints: list[list[tuple[str, str]]]
async def restore(
conversation_store: ConversationStore | None,
ctx: NodeContext,
config: LoopConfig,
) -> RestoredState | None:
"""Attempt to restore from a previous checkpoint.
Returns a ``RestoredState`` with conversation, accumulator, iteration
counter, and stall/doom-loop detection state: everything needed to
resume exactly where execution stopped.
"""
if conversation_store is None:
return None
# In isolated mode, filter parts by phase_id so the node only sees
# its own messages in the shared flat conversation store. In
# continuous mode (or when _restore is called for timer-resume)
# load all parts — the full conversation threads across nodes.
_is_continuous = getattr(ctx, "continuous_mode", False)
phase_filter = None if _is_continuous else ctx.node_id
conversation = await NodeConversation.restore(
conversation_store,
phase_id=phase_filter,
)
if conversation is None:
return None
accumulator = await OutputAccumulator.restore(conversation_store)
accumulator.spillover_dir = config.spillover_dir
accumulator.max_value_chars = config.max_output_value_chars
cursor = await conversation_store.read_cursor()
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
# Restore stall/doom-loop detection state
recent_responses: list[str] = cursor.get("recent_responses", []) if cursor else []
raw_fps = cursor.get("recent_tool_fingerprints", []) if cursor else []
recent_tool_fingerprints: list[list[tuple[str, str]]] = [
[tuple(pair) for pair in fps] # type: ignore[misc]
for fps in raw_fps
]
logger.info(
f"Restored event loop: iteration={start_iteration}, "
f"messages={conversation.message_count}, "
f"outputs={list(accumulator.values.keys())}, "
f"stall_window={len(recent_responses)}, "
f"doom_window={len(recent_tool_fingerprints)}"
)
return RestoredState(
conversation=conversation,
accumulator=accumulator,
start_iteration=start_iteration,
recent_responses=recent_responses,
recent_tool_fingerprints=recent_tool_fingerprints,
)
async def write_cursor(
conversation_store: ConversationStore | None,
ctx: NodeContext,
conversation: NodeConversation,
accumulator: OutputAccumulator,
iteration: int,
*,
recent_responses: list[str] | None = None,
recent_tool_fingerprints: list[list[tuple[str, str]]] | None = None,
) -> None:
"""Write checkpoint cursor for crash recovery.
Persists iteration counter, accumulator outputs, and stall/doom-loop
detection state so that resume picks up exactly where execution stopped.
"""
if conversation_store:
cursor = await conversation_store.read_cursor() or {}
cursor.update(
{
"iteration": iteration,
"node_id": ctx.node_id,
"next_seq": conversation.next_seq,
"outputs": accumulator.to_dict(),
}
)
# Persist stall/doom-loop detection state for reliable resume
if recent_responses is not None:
cursor["recent_responses"] = recent_responses
if recent_tool_fingerprints is not None:
# Convert list[list[tuple]] → list[list[list]] for JSON
cursor["recent_tool_fingerprints"] = [
[list(pair) for pair in fps] for fps in recent_tool_fingerprints
]
await conversation_store.write_cursor(cursor)
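The tuple-to-list conversion in `write_cursor` exists because JSON has no tuple type: `json.dumps` serialises tuples as arrays, which deserialise as lists, so `restore` must rebuild tuples for fingerprint equality checks to keep working. A quick round-trip demonstration:

```python
import json

# Doom-loop fingerprints: one window entry per iteration, each a list of
# (tool_name, args_digest) tuples.
fingerprints = [[("web_search", "query=x"), ("read_file", "path=y")]]

# Serialise: tuples become JSON arrays, which come back as lists.
raw = json.loads(json.dumps(fingerprints))

# Restore: rebuild tuples so comparisons against fresh fingerprints
# (which are tuples) still compare equal.
restored = [[tuple(pair) for pair in fps] for fps in raw]
```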
async def drain_injection_queue(
queue: asyncio.Queue,
conversation: NodeConversation,
*,
ctx: NodeContext,
describe_images_as_text_fn: (
Callable[[list[dict[str, Any]]], Awaitable[str | None]] | None
) = None,
) -> int:
"""Drain all pending injected events as user messages. Returns count."""
count = 0
while not queue.empty():
try:
content, is_client_input, image_content = queue.get_nowait()
logger.info(
"[drain] injected message (client_input=%s, images=%d): %s",
is_client_input,
len(image_content) if image_content else 0,
content[:200] if content else "(empty)",
)
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
logger.info(
"Model '%s' does not support images; attempting vision fallback",
ctx.llm.model,
)
if describe_images_as_text_fn is not None:
description = await describe_images_as_text_fn(image_content)
if description:
content = f"{content}\n\n{description}" if content else description
logger.info("[drain] image described as text via vision fallback")
else:
logger.info("[drain] no vision fallback available; images dropped")
image_content = None
# Real user input is stored as-is; external events get a prefix
if is_client_input:
await conversation.add_user_message(
content,
is_client_input=True,
image_content=image_content,
)
else:
await conversation.add_user_message(f"[External event]: {content}")
count += 1
except asyncio.QueueEmpty:
break
return count
async def drain_trigger_queue(
queue: asyncio.Queue,
conversation: NodeConversation,
) -> int:
"""Drain all pending trigger events as a single batched user message.
Multiple triggers are merged so the LLM sees them atomically and can
reason about all pending triggers before acting.
"""
triggers: list[TriggerEvent] = []
while not queue.empty():
try:
triggers.append(queue.get_nowait())
except asyncio.QueueEmpty:
break
if not triggers:
return 0
parts: list[str] = []
for t in triggers:
task = t.payload.get("task", "")
task_line = f"\nTask: {task}" if task else ""
payload_str = json.dumps(t.payload, default=str)
parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
combined = "\n\n".join(parts)
logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200])
await conversation.add_user_message(combined)
return len(triggers)
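The batched message format built above can be sketched in isolation; `TriggerEvent` here is a simplified stand-in for the framework dataclass, keeping only the fields the formatter reads:

```python
import json
from dataclasses import dataclass, field

@dataclass
class TriggerEvent:
    trigger_type: str
    source_id: str
    payload: dict = field(default_factory=dict)

def batch_triggers(triggers: list[TriggerEvent]) -> str:
    """Merge pending triggers into one user message so the LLM sees them atomically."""
    parts: list[str] = []
    for t in triggers:
        task = t.payload.get("task", "")
        task_line = f"\nTask: {task}" if task else ""
        payload_str = json.dumps(t.payload, default=str)
        parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
    return "\n\n".join(parts)
```

Batching matters when several timers or webhooks fire between iterations: injecting them one at a time would let the LLM act on the first before knowing about the rest.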
async def check_pause(
ctx: NodeContext,
conversation: NodeConversation,
iteration: int,
) -> bool:
"""
Check if pause has been requested. Returns True if paused.
Note: This check happens BEFORE starting iteration N, after completing N-1.
If paused, the node exits having completed {iteration} iterations (0 to iteration-1).
"""
# Check executor-level pause event (for /pause command, Ctrl+Z)
if ctx.pause_event and ctx.pause_event.is_set():
completed = iteration # 0-indexed: iteration=3 means 3 iterations completed (0,1,2)
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)")
return True
# Check context-level pause flags (legacy/alternative methods)
pause_requested = ctx.input_data.get("pause_requested", False)
if not pause_requested:
try:
pause_requested = ctx.memory.read("pause_requested") or False
except (PermissionError, KeyError):
pause_requested = False
if pause_requested:
completed = iteration
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)")
return True
return False
@@ -0,0 +1,360 @@
"""EventBus publishing helpers for the event loop.
Thin wrappers around EventBus.emit_*() calls that check for bus existence
before publishing. Extracted to reduce noise in the main orchestrator.
"""
from __future__ import annotations
import logging
import time
from framework.graph.conversation import NodeConversation
from framework.graph.event_loop.types import HookContext
from framework.graph.node import NodeContext
from framework.runtime.event_bus import EventBus
logger = logging.getLogger(__name__)
async def publish_loop_started(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
max_iterations: int,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_node_loop_started(
stream_id=stream_id,
node_id=node_id,
max_iterations=max_iterations,
execution_id=execution_id,
)
async def generate_action_plan(
event_bus: EventBus | None,
ctx: NodeContext,
stream_id: str,
node_id: str,
execution_id: str,
) -> None:
"""Generate a brief action plan via LLM and emit it as an SSE event.
Runs as a fire-and-forget task so it never blocks the main loop.
"""
try:
system_prompt = ctx.node_spec.system_prompt or ""
# Trim to keep the prompt small
prompt_summary = system_prompt[:500]
if len(system_prompt) > 500:
prompt_summary += "..."
tool_names = [t.name for t in ctx.available_tools]
output_keys = ctx.node_spec.output_keys or []
prompt = (
f'You are about to work on a task as node "{node_id}".\n\n'
f"System prompt:\n{prompt_summary}\n\n"
f"Tools available: {tool_names}\n"
f"Required outputs: {output_keys}\n\n"
f"Write a brief action plan (2-5 bullet points) describing "
f"what you will do to complete this task. Be specific and concise.\n"
f"Return ONLY the plan text, no preamble."
)
response = await ctx.llm.acomplete(
messages=[{"role": "user", "content": prompt}],
max_tokens=1024,
)
plan = response.content.strip()
if plan and event_bus:
await event_bus.emit_node_action_plan(
stream_id=stream_id,
node_id=node_id,
plan=plan,
execution_id=execution_id,
)
except Exception as e:
logger.warning("Action plan generation failed for node '%s': %s", node_id, e)
async def publish_iteration(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
iteration: int,
execution_id: str = "",
extra_data: dict | None = None,
) -> None:
if event_bus:
await event_bus.emit_node_loop_iteration(
stream_id=stream_id,
node_id=node_id,
iteration=iteration,
execution_id=execution_id,
extra_data=extra_data,
)
async def publish_llm_turn_complete(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
stop_reason: str,
model: str,
input_tokens: int,
output_tokens: int,
cached_tokens: int = 0,
execution_id: str = "",
iteration: int | None = None,
) -> None:
if event_bus:
await event_bus.emit_llm_turn_complete(
stream_id=stream_id,
node_id=node_id,
stop_reason=stop_reason,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cached_tokens=cached_tokens,
execution_id=execution_id,
iteration=iteration,
)
def log_skip_judge(
ctx: NodeContext,
node_id: str,
iteration: int,
feedback: str,
tool_calls: list[dict],
llm_text: str,
turn_tokens: dict[str, int],
iter_start: float,
) -> None:
"""Log a CONTINUE step that skips judge evaluation (e.g., waiting for input)."""
if ctx.runtime_logger:
ctx.runtime_logger.log_step(
node_id=node_id,
node_type="event_loop",
step_index=iteration,
verdict="CONTINUE",
verdict_feedback=feedback,
tool_calls=tool_calls,
llm_text=llm_text,
input_tokens=turn_tokens.get("input", 0),
output_tokens=turn_tokens.get("output", 0),
latency_ms=int((time.time() - iter_start) * 1000),
)
async def publish_loop_completed(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
iterations: int,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_node_loop_completed(
stream_id=stream_id,
node_id=node_id,
iterations=iterations,
execution_id=execution_id,
)
async def publish_context_usage(
event_bus: EventBus | None,
ctx: NodeContext,
conversation: NodeConversation,
trigger: str,
) -> None:
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
if not event_bus:
return
from framework.runtime.event_bus import AgentEvent, EventType
estimated = conversation.estimate_tokens()
max_tokens = conversation._max_context_tokens
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
await event_bus.publish(
AgentEvent(
type=EventType.CONTEXT_USAGE_UPDATED,
stream_id=ctx.stream_id or ctx.node_id,
node_id=ctx.node_id,
data={
"usage_ratio": round(ratio, 4),
"usage_pct": round(ratio * 100),
"message_count": conversation.message_count,
"estimated_tokens": estimated,
"max_context_tokens": max_tokens,
"trigger": trigger,
},
)
)
async def publish_stalled(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_node_stalled(
stream_id=stream_id,
node_id=node_id,
reason="Consecutive similar responses detected",
execution_id=execution_id,
)
async def publish_text_delta(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
content: str,
snapshot: str,
ctx: NodeContext,
execution_id: str = "",
iteration: int | None = None,
inner_turn: int = 0,
) -> None:
if event_bus:
if ctx.node_spec.client_facing:
await event_bus.emit_client_output_delta(
stream_id=stream_id,
node_id=node_id,
content=content,
snapshot=snapshot,
execution_id=execution_id,
iteration=iteration,
inner_turn=inner_turn,
)
else:
await event_bus.emit_llm_text_delta(
stream_id=stream_id,
node_id=node_id,
content=content,
snapshot=snapshot,
execution_id=execution_id,
inner_turn=inner_turn,
)
async def publish_tool_started(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
tool_input: dict,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_tool_call_started(
stream_id=stream_id,
node_id=node_id,
tool_use_id=tool_use_id,
tool_name=tool_name,
tool_input=tool_input,
execution_id=execution_id,
)
async def publish_tool_completed(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
tool_use_id: str,
tool_name: str,
result: str,
is_error: bool,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_tool_call_completed(
stream_id=stream_id,
node_id=node_id,
tool_use_id=tool_use_id,
tool_name=tool_name,
result=result,
is_error=is_error,
execution_id=execution_id,
)
async def publish_judge_verdict(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
action: str,
feedback: str = "",
judge_type: str = "implicit",
iteration: int = 0,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_judge_verdict(
stream_id=stream_id,
node_id=node_id,
action=action,
feedback=feedback,
judge_type=judge_type,
iteration=iteration,
execution_id=execution_id,
)
async def publish_output_key_set(
event_bus: EventBus | None,
stream_id: str,
node_id: str,
key: str,
execution_id: str = "",
) -> None:
if event_bus:
await event_bus.emit_output_key_set(
stream_id=stream_id, node_id=node_id, key=key, execution_id=execution_id
)
async def run_hooks(
hooks_config: dict[str, list],
event: str,
conversation: NodeConversation,
trigger: str | None = None,
) -> None:
"""Run all registered hooks for *event*, applying their results.
Each hook receives a HookContext and may return a HookResult that:
- replaces the system prompt (result.system_prompt)
- injects an extra user message (result.inject)
Hooks run in registration order; each sees the prompt as left by the
previous hook.
"""
hook_list = hooks_config.get(event, [])
if not hook_list:
return
for hook in hook_list:
ctx = HookContext(
event=event,
trigger=trigger,
system_prompt=conversation.system_prompt,
)
try:
result = await hook(ctx)
except Exception:
            logger.warning("Hook for event '%s' raised an exception", event, exc_info=True)
continue
if result is None:
continue
if result.system_prompt:
conversation.update_system_prompt(result.system_prompt)
if result.inject:
await conversation.add_user_message(result.inject)
@@ -0,0 +1,175 @@
"""Judge evaluation pipeline for the event loop."""
from __future__ import annotations
import logging
from collections.abc import Callable
from framework.graph.conversation import NodeConversation
from framework.graph.event_loop.types import JudgeProtocol, JudgeVerdict, OutputAccumulator
from framework.graph.node import NodeContext
logger = logging.getLogger(__name__)
class SubagentJudge:
"""Judge for subagent execution."""
def __init__(self, task: str, max_iterations: int = 10):
self._task = task
self._max_iterations = max_iterations
async def evaluate(self, context: dict[str, object]) -> JudgeVerdict:
missing = context.get("missing_keys", [])
if not isinstance(missing, list) or not missing:
return JudgeVerdict(action="ACCEPT", feedback="")
iteration = context.get("iteration", 0)
if not isinstance(iteration, int):
iteration = 0
remaining = self._max_iterations - iteration - 1
if remaining <= 3:
urgency = (
f"URGENT: Only {remaining} iterations left. "
f"Stop all other work and call set_output NOW for: {missing}"
)
elif remaining <= self._max_iterations // 2:
urgency = (
f"WARNING: {remaining} iterations remaining. "
f"You must call set_output for: {missing}"
)
else:
urgency = f"Missing output keys: {missing}. Use set_output to provide them."
return JudgeVerdict(action="RETRY", feedback=f"Your task: {self._task}\n{urgency}")
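The urgency ramp in `SubagentJudge.evaluate` can be exercised in isolation. The sketch below stubs `JudgeVerdict` as a plain dataclass (an assumption — the real class lives in `framework.graph.event_loop.types`) and replays the three feedback tiers:

```python
import asyncio
from dataclasses import dataclass

# Assumed minimal shape of framework.graph.event_loop.types.JudgeVerdict.
@dataclass
class JudgeVerdict:
    action: str
    feedback: str = ""

class SubagentJudge:
    """Self-contained copy of the judge above."""
    def __init__(self, task: str, max_iterations: int = 10):
        self._task = task
        self._max_iterations = max_iterations

    async def evaluate(self, context: dict[str, object]) -> JudgeVerdict:
        missing = context.get("missing_keys", [])
        if not isinstance(missing, list) or not missing:
            return JudgeVerdict(action="ACCEPT", feedback="")
        iteration = context.get("iteration", 0)
        if not isinstance(iteration, int):
            iteration = 0
        remaining = self._max_iterations - iteration - 1
        if remaining <= 3:  # final stretch: demand set_output immediately
            urgency = f"URGENT: Only {remaining} iterations left for: {missing}"
        elif remaining <= self._max_iterations // 2:  # past the halfway mark
            urgency = f"WARNING: {remaining} iterations remaining for: {missing}"
        else:  # plenty of budget: plain reminder
            urgency = f"Missing output keys: {missing}. Use set_output to provide them."
        return JudgeVerdict(action="RETRY", feedback=f"Your task: {self._task}\n{urgency}")

judge = SubagentJudge(task="summarize", max_iterations=10)
for it in (0, 5, 8):
    v = asyncio.run(judge.evaluate({"missing_keys": ["summary"], "iteration": it}))
    print(it, v.action, v.feedback.splitlines()[-1])
```

With `max_iterations=10`, iterations 0, 5, and 8 land in the reminder, WARNING, and URGENT tiers respectively; an empty `missing_keys` accepts immediately.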
async def judge_turn(
*,
mark_complete_flag: bool,
judge: JudgeProtocol | None,
ctx: NodeContext,
conversation: NodeConversation,
accumulator: OutputAccumulator,
assistant_text: str,
tool_results: list[dict[str, object]],
iteration: int,
get_missing_output_keys_fn: Callable[
[OutputAccumulator, list[str] | None, list[str] | None],
list[str],
],
max_context_tokens: int,
) -> JudgeVerdict:
"""Evaluate the current state using judge or implicit logic.
Evaluation levels (in order):
0. Short-circuits: mark_complete, skip_judge, tool-continue.
1. Custom judge (JudgeProtocol) full authority when set.
2. Implicit judge output-key check + optional conversation-aware
quality gate (when ``success_criteria`` is defined).
Returns a JudgeVerdict. ``feedback=None`` means no real evaluation
happened (skip_judge, tool-continue); the caller must not inject a
feedback message. Any non-None feedback (including ``""``) means a
real evaluation occurred and will be logged into the conversation.
"""
# --- Level 0: short-circuits (no evaluation) -----------------------
if mark_complete_flag:
return JudgeVerdict(action="ACCEPT")
if ctx.node_spec.skip_judge:
return JudgeVerdict(action="RETRY") # feedback=None → not logged
# --- Level 1: custom judge -----------------------------------------
if judge is not None:
context = {
"assistant_text": assistant_text,
"tool_calls": tool_results,
"output_accumulator": accumulator.to_dict(),
"accumulator": accumulator,
"iteration": iteration,
"conversation_summary": conversation.export_summary(),
"output_keys": ctx.node_spec.output_keys,
"missing_keys": get_missing_output_keys_fn(
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
),
}
verdict = await judge.evaluate(context)
# Ensure evaluated RETRY always carries feedback for logging.
if verdict.action == "RETRY" and not verdict.feedback:
return JudgeVerdict(action="RETRY", feedback="Custom judge returned RETRY.")
return verdict
# --- Level 2: implicit judge ---------------------------------------
# Real tool calls were made — let the agent keep working.
if tool_results:
return JudgeVerdict(action="RETRY") # feedback=None → not logged
missing = get_missing_output_keys_fn(
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
)
if missing:
return JudgeVerdict(
action="RETRY",
feedback=(
f"Task incomplete. Required outputs not yet produced: {missing}. "
f"Follow your system prompt instructions to complete the work."
),
)
# All output keys present — run safety checks before accepting.
output_keys = ctx.node_spec.output_keys or []
nullable_keys = set(ctx.node_spec.nullable_output_keys or [])
# All-nullable with nothing set → node produced nothing useful.
all_nullable = output_keys and nullable_keys >= set(output_keys)
none_set = not any(accumulator.get(k) is not None for k in output_keys)
if all_nullable and none_set:
return JudgeVerdict(
action="RETRY",
feedback=(
f"No output keys have been set yet. "
f"Use set_output to set at least one of: {output_keys}"
),
)
# Client-facing with no output keys → continuous interaction node.
# Inject tool-use pressure instead of auto-accepting.
if not output_keys and ctx.node_spec.client_facing:
return JudgeVerdict(
action="RETRY",
feedback=(
"STOP describing what you will do. "
"You have FULL access to all tools — file creation, "
"shell commands, MCP tools — and you CAN call them "
"directly in your response. Respond ONLY with tool "
"calls, no prose. Execute the task now."
),
)
# Level 2b: conversation-aware quality check (if success_criteria set)
if ctx.node_spec.success_criteria and ctx.llm:
from framework.graph.conversation_judge import evaluate_phase_completion
verdict = await evaluate_phase_completion(
llm=ctx.llm,
conversation=conversation,
phase_name=ctx.node_spec.name,
phase_description=ctx.node_spec.description,
success_criteria=ctx.node_spec.success_criteria,
accumulator_state=accumulator.to_dict(),
max_context_tokens=max_context_tokens,
)
if verdict.action != "ACCEPT":
return JudgeVerdict(
action=verdict.action,
feedback=verdict.feedback or "Phase criteria not met.",
)
return JudgeVerdict(action="ACCEPT", feedback="")
@@ -0,0 +1,106 @@
"""Stall and doom-loop detection for the event loop.
Pure functions with no class dependencies, safe to call from any context.
"""
from __future__ import annotations
import json
def ngram_similarity(s1: str, s2: str, n: int = 2) -> float:
"""Jaccard similarity of n-gram sets.
Returns 0.0-1.0, where 1.0 is exact match.
    Fast: O(len(s1) + len(s2)) using set operations.
"""
    def _ngrams(s: str) -> set[str]:
        # Whitespace-only strings contribute no n-grams.
        if not s.strip():
            return set()
        return {s[i : i + n] for i in range(len(s) - n + 1)}
if not s1 or not s2:
return 0.0
ngrams1, ngrams2 = _ngrams(s1.lower()), _ngrams(s2.lower())
if not ngrams1 or not ngrams2:
return 0.0
intersection = len(ngrams1 & ngrams2)
union = len(ngrams1 | ngrams2)
return intersection / union if union else 0.0
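A quick sanity check of the Jaccard n-gram similarity (self-contained copy of the function, identical behavior):

```python
def ngram_similarity(s1: str, s2: str, n: int = 2) -> float:
    # Jaccard similarity over sets of character n-grams (default: bigrams).
    def _ngrams(s: str) -> set[str]:
        if not s.strip():
            return set()
        return {s[i : i + n] for i in range(len(s) - n + 1)}

    if not s1 or not s2:
        return 0.0
    g1, g2 = _ngrams(s1.lower()), _ngrams(s2.lower())
    if not g1 or not g2:
        return 0.0
    union = g1 | g2
    return len(g1 & g2) / len(union) if union else 0.0

# Near-duplicate phrasings score high; unrelated text scores low.
print(ngram_similarity("I'm still stuck", "I'm stuck"))      # 8/12 ≈ 0.67
print(ngram_similarity("attempt search", "browse results"))  # 1/25 = 0.04
print(ngram_similarity("same text", "same text"))            # 1.0
```

This is why "I'm still stuck" vs "I'm stuck" trips the stall detector while genuinely different turns do not.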
def is_stalled(
recent_responses: list[str],
threshold: int,
similarity_threshold: float,
) -> bool:
"""Detect stall using n-gram similarity.
    Detects when ALL N consecutive responses are mutually similar
    (>= similarity_threshold). A single dissimilar response resets the signal.
    This catches phrases like "I'm still stuck" vs "I'm stuck"
    without false positives on "attempt 1" vs "attempt 2".
"""
if len(recent_responses) < threshold:
return False
if not recent_responses[0]:
return False
# Every consecutive pair must be similar
for i in range(1, len(recent_responses)):
if ngram_similarity(recent_responses[i], recent_responses[i - 1]) < similarity_threshold:
return False
return True
def fingerprint_tool_calls(
tool_results: list[dict],
) -> list[tuple[str, str]]:
"""Create deterministic fingerprints for a turn's tool calls.
    Each fingerprint is ``(tool_name, canonical_args_json)``. Order-sensitive,
    so [search("a"), fetch("b")] != [fetch("b"), search("a")].
"""
fingerprints = []
for tr in tool_results:
name = tr.get("tool_name", "")
args = tr.get("tool_input", {})
try:
canonical = json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
canonical = str(args)
fingerprints.append((name, canonical))
return fingerprints
def is_tool_doom_loop(
recent_tool_fingerprints: list[list[tuple[str, str]]],
threshold: int,
enabled: bool = True,
) -> tuple[bool, str]:
"""Detect doom loop via exact fingerprint match.
Detects when N consecutive turns invoke the same tools with
identical (canonicalized) arguments. Different arguments mean
different work, so only exact matches count.
Returns (is_doom_loop, description).
"""
if not enabled:
return False, ""
if len(recent_tool_fingerprints) < threshold:
return False, ""
first = recent_tool_fingerprints[0]
if not first:
return False, ""
# All turns in the window must match the first exactly
if all(fp == first for fp in recent_tool_fingerprints[1:]):
tool_names = [name for name, _ in first]
desc = (
f"Doom loop detected: {len(recent_tool_fingerprints)} "
f"identical consecutive tool calls ({', '.join(tool_names)})"
)
return True, desc
return False, ""
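The fingerprint-based detector can be exercised end-to-end. Self-contained copies of the two functions above, with `is_tool_doom_loop` simplified here to return a bare bool rather than the (bool, description) tuple:

```python
import json

def fingerprint_tool_calls(tool_results: list[dict]) -> list[tuple[str, str]]:
    # (tool_name, canonical JSON args) per call, order preserved.
    fps = []
    for tr in tool_results:
        args = tr.get("tool_input", {})
        try:
            canonical = json.dumps(args, sort_keys=True, default=str)
        except (TypeError, ValueError):
            canonical = str(args)
        fps.append((tr.get("tool_name", ""), canonical))
    return fps

def is_tool_doom_loop(window: list[list[tuple[str, str]]], threshold: int) -> bool:
    # True when `threshold` consecutive turns made identical tool calls.
    if len(window) < threshold or not window[0]:
        return False
    return all(fp == window[0] for fp in window[1:])

same = [{"tool_name": "search", "tool_input": {"page": 1, "q": "x"}}]
different = [{"tool_name": "search", "tool_input": {"q": "x", "page": 2}}]

print(is_tool_doom_loop([fingerprint_tool_calls(same)] * 3, threshold=3))  # True
print(is_tool_doom_loop(
    [fingerprint_tool_calls(same), fingerprint_tool_calls(same),
     fingerprint_tool_calls(different)], threshold=3))                     # False
```

Note that `sort_keys=True` makes the canonicalization key-order independent: `{"q": "x", "page": 1}` and `{"page": 1, "q": "x"}` produce the same fingerprint, so cosmetic argument reordering still counts as a repeat.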
@@ -0,0 +1,412 @@
"""Subagent execution for the event loop.
Handles the full subagent lifecycle: validation, context setup, tool filtering,
conversation store derivation, execution, and cleanup. Also includes the
EscalationReceiver helper used for subagent-to-queen escalation routing.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any
from framework.graph.conversation import ConversationStore
from framework.graph.event_loop.judge_pipeline import SubagentJudge
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
from framework.graph.node import NodeContext, SharedMemory
from framework.llm.provider import ToolResult, ToolUse
from framework.runtime.event_bus import EventBus
if TYPE_CHECKING:
from framework.graph.event_loop_node import EventLoopNode
logger = logging.getLogger(__name__)
class EscalationReceiver:
"""Temporary receiver registered in node_registry for subagent escalation routing.
When a subagent calls ``report_to_parent(wait_for_response=True)``, the callback
creates one of these, registers it under a unique escalation ID in the executor's
``node_registry``, and awaits ``wait()``. The TUI / runner calls
``inject_input(escalation_id, content)`` which the ``ExecutionStream`` routes here
via ``inject_event()`` matching the same ``hasattr(node, "inject_event")`` check
used for regular ``EventLoopNode`` instances.
"""
def __init__(self) -> None:
self._event = asyncio.Event()
self._response: str | None = None
self._awaiting_input = True # So inject_worker_message() can prefer us
async def inject_event(
self,
content: str,
*,
is_client_input: bool = False,
image_content: list[dict[str, Any]] | None = None,
) -> None:
"""Called by ExecutionStream.inject_input() when the user responds."""
self._response = content
self._event.set()
async def wait(self) -> str | None:
"""Block until inject_event() delivers the user's response."""
await self._event.wait()
return self._response
async def execute_subagent(
ctx: NodeContext,
agent_id: str,
task: str,
*,
config: LoopConfig,
event_loop_node_cls: type[EventLoopNode],
escalation_receiver_cls: type[EscalationReceiver],
accumulator: OutputAccumulator | None = None,
event_bus: EventBus | None = None,
tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
conversation_store: ConversationStore | None = None,
subagent_instance_counter: dict[str, int] | None = None,
) -> ToolResult:
"""Execute a subagent and return the result as a ToolResult.
The subagent:
- Gets a fresh conversation with just the task
- Has read-only access to the parent's readable memory
- Cannot delegate to its own subagents (prevents recursion)
- Returns its output in structured JSON format
Args:
ctx: Parent node's context (for memory, tools, LLM access).
agent_id: The node ID of the subagent to invoke.
task: The task description to give the subagent.
accumulator: Parent's OutputAccumulator.
event_bus: EventBus for lifecycle events.
config: LoopConfig for iteration/tool limits.
tool_executor: Tool executor callable.
conversation_store: Parent conversation store (for deriving subagent store).
subagent_instance_counter: Mutable counter dict for unique subagent paths.
Returns:
ToolResult with structured JSON output.
"""
# Log subagent invocation start
    sep = "=" * 60
    logger.info(
        "\n%s\n🤖 SUBAGENT INVOCATION\n%s\n"
        "Parent Node: %s\n"
        "Subagent ID: %s\n"
        "Task: %s\n%s",
        sep,
        sep,
        ctx.node_id,
        agent_id,
        task[:500] + "..." if len(task) > 500 else task,
        sep,
    )
# 1. Validate agent exists in registry
if agent_id not in ctx.node_registry:
return ToolResult(
tool_use_id="",
content=json.dumps(
{
"message": f"Sub-agent '{agent_id}' not found in registry",
"data": None,
"metadata": {"agent_id": agent_id, "success": False, "error": "not_found"},
}
),
is_error=True,
)
subagent_spec = ctx.node_registry[agent_id]
# 2. Create read-only memory snapshot
parent_data = ctx.memory.read_all()
# Merge in-flight outputs from the parent's accumulator.
if accumulator:
for key, value in accumulator.to_dict().items():
if key not in parent_data:
parent_data[key] = value
subagent_memory = SharedMemory()
for key, value in parent_data.items():
subagent_memory.write(key, value, validate=False)
read_keys = set(parent_data.keys()) | set(subagent_spec.input_keys or [])
scoped_memory = subagent_memory.with_permissions(
read_keys=list(read_keys),
write_keys=[], # Read-only!
)
# 2b. Compute instance counter early so the callback and child context
# share the same stable node_id for this subagent invocation.
if subagent_instance_counter is not None:
subagent_instance_counter.setdefault(agent_id, 0)
subagent_instance_counter[agent_id] += 1
subagent_instance = str(subagent_instance_counter[agent_id])
else:
subagent_instance = "1"
if subagent_instance == "1":
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
else:
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{subagent_instance}"
# 2c. Set up report callback (one-way channel to parent / event bus)
subagent_reports: list[dict] = []
async def _report_callback(
message: str,
data: dict | None = None,
*,
wait_for_response: bool = False,
) -> str | None:
subagent_reports.append({"message": message, "data": data, "timestamp": time.time()})
if event_bus:
await event_bus.emit_subagent_report(
stream_id=ctx.node_id,
node_id=sa_node_id,
subagent_id=agent_id,
message=message,
data=data,
execution_id=ctx.execution_id,
)
if not wait_for_response:
return None
if not event_bus:
logger.warning(
"Subagent '%s' requested user response but no event_bus available",
agent_id,
)
return None
# Create isolated receiver and register for input routing
import uuid
escalation_id = f"{ctx.node_id}:escalation:{uuid.uuid4().hex[:8]}"
receiver = escalation_receiver_cls()
registry = ctx.shared_node_registry
registry[escalation_id] = receiver
try:
await event_bus.emit_escalation_requested(
stream_id=ctx.stream_id or ctx.node_id,
node_id=escalation_id,
reason=f"Subagent report (wait_for_response) from {agent_id}",
context=message,
execution_id=ctx.execution_id,
)
# Block until queen responds
return await receiver.wait()
finally:
registry.pop(escalation_id, None)
# 3. Filter tools for subagent
subagent_tool_names = set(subagent_spec.tools or [])
tool_source = ctx.all_tools if ctx.all_tools else ctx.available_tools
    # GCU auto-population: a GCU node with no explicit tool list gets every
    # tool except delegation.
if subagent_spec.node_type == "gcu" and not subagent_tool_names:
subagent_tools = [t for t in tool_source if t.name != "delegate_to_sub_agent"]
else:
subagent_tools = [
t
for t in tool_source
if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent"
]
missing = subagent_tool_names - {t.name for t in subagent_tools}
if missing:
logger.warning(
"Subagent '%s' requested tools not found in catalog: %s",
agent_id,
sorted(missing),
)
logger.info(
"📦 Subagent '%s' configuration:\n"
" - System prompt: %s\n"
" - Tools available (%d): %s\n"
" - Memory keys inherited: %s",
agent_id,
(subagent_spec.system_prompt[:200] + "...")
if subagent_spec.system_prompt and len(subagent_spec.system_prompt) > 200
else subagent_spec.system_prompt,
len(subagent_tools),
[t.name for t in subagent_tools],
list(parent_data.keys()),
)
# 4. Build subagent context
max_iter = min(config.max_iterations, 10)
subagent_ctx = NodeContext(
runtime=ctx.runtime,
node_id=sa_node_id,
node_spec=subagent_spec,
memory=scoped_memory,
input_data={"task": task, **parent_data},
llm=ctx.llm,
available_tools=subagent_tools,
goal_context=(
f"Your specific task: {task}\n\n"
f"COMPLETION REQUIREMENTS:\n"
f"When your task is done, you MUST call set_output() "
f"for each required key: {subagent_spec.output_keys}\n"
f"Alternatively, call report_to_parent(mark_complete=true) "
f"with your findings in message/data.\n"
f"You have a maximum of {max_iter} turns to complete this task."
),
goal=ctx.goal,
max_tokens=ctx.max_tokens,
runtime_logger=ctx.runtime_logger,
is_subagent_mode=True, # Prevents nested delegation
report_callback=_report_callback,
node_registry={}, # Empty - no nested subagents
shared_node_registry=ctx.shared_node_registry, # For escalation routing
)
# 5. Create and execute subagent EventLoopNode
subagent_conv_store = None
if conversation_store is not None:
from framework.storage.conversation_store import FileConversationStore
parent_base = getattr(conversation_store, "_base", None)
if parent_base is not None:
conversations_dir = parent_base.parent
subagent_dir_name = f"{agent_id}-{subagent_instance}"
subagent_store_path = conversations_dir / subagent_dir_name
subagent_conv_store = FileConversationStore(base_path=subagent_store_path)
# Derive a subagent-scoped spillover dir
subagent_spillover = None
if config.spillover_dir:
subagent_spillover = str(Path(config.spillover_dir) / agent_id / subagent_instance)
subagent_node = event_loop_node_cls(
event_bus=event_bus,
judge=SubagentJudge(task=task, max_iterations=max_iter),
config=LoopConfig(
max_iterations=max_iter,
max_tool_calls_per_turn=config.max_tool_calls_per_turn,
tool_call_overflow_margin=config.tool_call_overflow_margin,
max_context_tokens=config.max_context_tokens,
stall_detection_threshold=config.stall_detection_threshold,
max_tool_result_chars=config.max_tool_result_chars,
spillover_dir=subagent_spillover,
),
tool_executor=tool_executor,
conversation_store=subagent_conv_store,
)
# Inject a unique GCU browser profile for this subagent
_profile_token = None
try:
from gcu.browser.session import set_active_profile as _set_gcu_profile
_profile_token = _set_gcu_profile(f"{agent_id}-{subagent_instance}")
except ImportError:
pass # GCU tools not installed; no-op
try:
logger.info("🚀 Starting subagent '%s' execution...", agent_id)
start_time = time.time()
result = await subagent_node.execute(subagent_ctx)
latency_ms = int((time.time() - start_time) * 1000)
separator = "-" * 60
logger.info(
"\n%s\n"
"✅ SUBAGENT '%s' COMPLETED\n"
"%s\n"
"Success: %s\n"
"Latency: %dms\n"
"Tokens used: %s\n"
"Output keys: %s\n"
"%s",
separator,
agent_id,
separator,
result.success,
latency_ms,
result.tokens_used,
list(result.output.keys()) if result.output else [],
separator,
)
result_json = {
"message": (
f"Sub-agent '{agent_id}' completed successfully"
if result.success
else f"Sub-agent '{agent_id}' failed: {result.error}"
),
"data": result.output,
"reports": subagent_reports if subagent_reports else None,
"metadata": {
"agent_id": agent_id,
"success": result.success,
"tokens_used": result.tokens_used,
"latency_ms": latency_ms,
"report_count": len(subagent_reports),
},
}
return ToolResult(
tool_use_id="",
content=json.dumps(result_json, indent=2, default=str),
is_error=not result.success,
)
except Exception as e:
logger.exception(
"\n" + "!" * 60 + "\n❌ SUBAGENT '%s' FAILED\nError: %s\n" + "!" * 60,
agent_id,
str(e),
)
result_json = {
"message": f"Sub-agent '{agent_id}' raised exception: {e}",
"data": None,
"metadata": {
"agent_id": agent_id,
"success": False,
"error": str(e),
},
}
return ToolResult(
tool_use_id="",
content=json.dumps(result_json, indent=2),
is_error=True,
)
finally:
# Restore the GCU profile context
if _profile_token is not None:
from gcu.browser.session import _active_profile as _gcu_profile_var
_gcu_profile_var.reset(_profile_token)
# Stop the browser session for this subagent's profile
if tool_executor is not None:
_subagent_profile = f"{agent_id}-{subagent_instance}"
try:
_stop_use = ToolUse(
id="gcu-cleanup",
name="browser_stop",
input={"profile": _subagent_profile},
)
_stop_result = tool_executor(_stop_use)
if asyncio.iscoroutine(_stop_result) or asyncio.isfuture(_stop_result):
await _stop_result
except Exception as _gcu_exc:
logger.warning(
"GCU browser_stop failed for profile %r: %s",
_subagent_profile,
_gcu_exc,
)
@@ -0,0 +1,369 @@
"""Synthetic tool builders for the event loop.
Factory functions that create ``Tool`` definitions for framework-level
synthetic tools (set_output, ask_user, escalate, delegate, report_to_parent).
Also includes the ``handle_set_output`` validation logic.
All functions are pure: they receive explicit parameters and return
``Tool`` or ``ToolResult`` objects with no side effects.
"""
from __future__ import annotations
from typing import Any
from framework.llm.provider import Tool, ToolResult
def build_ask_user_tool() -> Tool:
"""Build the synthetic ask_user tool for explicit user-input requests.
Client-facing nodes call ask_user() when they need to pause and wait
for user input. Text-only turns WITHOUT ask_user flow through without
blocking, allowing progress updates and summaries to stream freely.
"""
return Tool(
name="ask_user",
description=(
"You MUST call this tool whenever you need the user's response. "
"Always call it after greeting the user, asking a question, or "
"requesting approval. Do NOT call it for status updates or "
"summaries that don't require a response. "
"Always include 2-3 predefined options. The UI automatically "
"appends an 'Other' free-text input after your options, so NEVER "
"include catch-all options like 'Custom idea', 'Something else', "
"'Other', or 'None of the above' — the UI handles that. "
"When the question primarily needs a typed answer but you must "
"include options, make one option signal that typing is expected "
"(e.g. 'I\\'ll type my response'). This helps users discover the "
"free-text input. "
"The ONLY exception: omit options when the question demands a "
"free-form answer the user must type out (e.g. 'Describe your "
"agent idea', 'Paste the error message'). "
            "Example with options: "
            '{"question": "What would you like to do?", "options": '
'["Build a new agent", "Modify existing agent", "Run tests"]} '
"Free-form example: "
'{"question": "Describe the agent you want to build."}'
),
parameters={
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The question or prompt shown to the user.",
},
"options": {
"type": "array",
"items": {"type": "string"},
"description": (
"2-3 specific predefined choices. Include in most cases. "
'Example: ["Option A", "Option B", "Option C"]. '
"The UI always appends an 'Other' free-text input, so "
"do NOT include catch-alls like 'Custom idea' or 'Other'. "
"Omit ONLY when the user must type a free-form answer."
),
"minItems": 2,
"maxItems": 3,
},
},
"required": ["question"],
},
)
def build_ask_user_multiple_tool() -> Tool:
"""Build the synthetic ask_user_multiple tool for batched questions.
Queen-only tool that presents multiple questions at once so the user
can answer them all in a single interaction rather than one at a time.
"""
return Tool(
name="ask_user_multiple",
description=(
"Ask the user multiple questions at once. Use this instead of "
"ask_user when you have 2 or more questions to ask in the same "
"turn — it lets the user answer everything in one go rather than "
"going back and forth. Each question can have its own predefined "
"options (2-3 choices) or be free-form. The UI renders all "
"questions together with a single Submit button. "
"ALWAYS prefer this over ask_user when you have multiple things "
"to clarify. "
"IMPORTANT: Do NOT repeat the questions in your text response — "
"the widget renders them. Keep your text to a brief intro only. "
            "Example: "
            '{"questions": ['
' {"id": "scope", "prompt": "What scope?", "options": ["Full", "Partial"]},'
' {"id": "format", "prompt": "Output format?", "options": ["PDF", "CSV", "JSON"]},'
' {"id": "details", "prompt": "Any special requirements?"}'
"]}"
),
parameters={
"type": "object",
"properties": {
"questions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": (
"Short identifier for this question (used in the response)."
),
},
"prompt": {
"type": "string",
"description": "The question text shown to the user.",
},
"options": {
"type": "array",
"items": {"type": "string"},
"description": (
"2-3 predefined choices. The UI appends an "
"'Other' free-text input automatically. "
"Omit only when the user must type a free-form answer."
),
"minItems": 2,
"maxItems": 3,
},
},
"required": ["id", "prompt"],
},
"minItems": 2,
"maxItems": 8,
"description": "List of questions to present to the user.",
},
},
"required": ["questions"],
},
)
def build_set_output_tool(output_keys: list[str] | None) -> Tool | None:
"""Build the synthetic set_output tool for explicit output declaration."""
if not output_keys:
return None
return Tool(
name="set_output",
description=(
"Set an output value for this node. Call once per output key. "
"Use this for brief notes, counts, status, and file references — "
"NOT for large data payloads. When a tool result was saved to a "
"data file, pass the filename as the value "
"(e.g. 'google_sheets_get_values_1.txt') so the next phase can "
"load the full data. Values exceeding ~2000 characters are "
"auto-saved to data files. "
f"Valid keys: {output_keys}"
),
parameters={
"type": "object",
"properties": {
"key": {
"type": "string",
"description": f"Output key. Must be one of: {output_keys}",
"enum": output_keys,
},
"value": {
"type": "string",
"description": (
"The output value — a brief note, count, status, "
"or data filename reference."
),
},
},
"required": ["key", "value"],
},
)
def build_escalate_tool() -> Tool:
"""Build the synthetic escalate tool for worker -> queen handoff."""
return Tool(
name="escalate",
description=(
            "Escalate to the queen when you need user input, are "
            "blocked by errors, are missing "
            "credentials, or face ambiguous constraints that require supervisor "
"guidance. Include a concise reason and optional context. "
"The node will pause until the queen injects guidance."
),
parameters={
"type": "object",
"properties": {
"reason": {
"type": "string",
"description": (
"Short reason for escalation (e.g. 'Tool repeatedly failing')."
),
},
"context": {
"type": "string",
"description": "Optional diagnostic details for the queen.",
},
},
"required": ["reason"],
},
)
def build_delegate_tool(sub_agents: list[str], node_registry: dict[str, Any]) -> Tool | None:
"""Build the synthetic delegate_to_sub_agent tool for subagent invocation.
Args:
sub_agents: List of node IDs that can be invoked as subagents.
node_registry: Map of node_id -> NodeSpec for looking up subagent descriptions.
Returns:
Tool definition if sub_agents is non-empty, None otherwise.
"""
if not sub_agents:
return None
agent_descriptions = []
for agent_id in sub_agents:
spec = node_registry.get(agent_id)
if spec:
desc = getattr(spec, "description", "(no description)")
agent_descriptions.append(f"- {agent_id}: {desc}")
else:
agent_descriptions.append(f"- {agent_id}: (not found in registry)")
return Tool(
name="delegate_to_sub_agent",
description=(
"Delegate a task to a specialized sub-agent. The sub-agent runs "
"autonomously with read-only access to current memory and returns "
"its result. Use this to parallelize work or leverage specialized capabilities.\n\n"
"Available sub-agents:\n" + "\n".join(agent_descriptions)
),
parameters={
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": f"The sub-agent to invoke. Must be one of: {sub_agents}",
"enum": sub_agents,
},
"task": {
"type": "string",
"description": (
"The task description for the sub-agent to execute. "
"Be specific about what you want the sub-agent to do and "
"what information to return."
),
},
},
"required": ["agent_id", "task"],
},
)
def build_report_to_parent_tool() -> Tool:
"""Build the synthetic report_to_parent tool for sub-agent progress reports.
Sub-agents call this to send one-way progress updates, partial findings,
or status reports to the parent node (and external observers via event bus)
without blocking execution.
When ``wait_for_response`` is True, the sub-agent blocks until the parent
relays the user's response — used for escalation (e.g. login pages, CAPTCHAs).
    When ``mark_complete`` is True, the sub-agent terminates immediately after
    sending the report; there is no need to call set_output for each output key.
"""
return Tool(
name="report_to_parent",
description=(
"Send a report to the parent agent. By default this is fire-and-forget: "
"the parent receives the report but does not respond. "
"Set wait_for_response=true to BLOCK until the user replies — use this "
"when you need human intervention (e.g. login pages, CAPTCHAs, "
"authentication walls). The user's response is returned as the tool result. "
"Set mark_complete=true to finish your task and terminate immediately "
"after sending the report — use this when your findings are in the "
"message/data fields and you don't need to call set_output."
),
parameters={
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "A human-readable status or progress message.",
},
"data": {
"type": "object",
"description": "Optional structured data to include with the report.",
},
"wait_for_response": {
"type": "boolean",
"description": (
"If true, block execution until the user responds. "
"Use for escalation scenarios requiring human intervention."
),
"default": False,
},
"mark_complete": {
"type": "boolean",
"description": (
"If true, terminate the sub-agent immediately after sending "
"this report. The report message and data are delivered to the "
"parent as the final result. No set_output calls are needed."
),
"default": False,
},
},
"required": ["message"],
},
)
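The three calling modes the description covers (fire-and-forget, blocking escalation, terminal report) can be sketched as argument payloads; `validate_report` is a hypothetical helper for illustration, not part of the framework:

```python
def validate_report(args: dict) -> bool:
    """Check a report_to_parent call against the schema above (sketch)."""
    # message is the only required field and must be a string
    if "message" not in args or not isinstance(args["message"], str):
        return False
    # the two flags default to False and must be booleans when present
    return isinstance(args.get("wait_for_response", False), bool) and \
           isinstance(args.get("mark_complete", False), bool)

fire_and_forget = {"message": "Scraped 40/120 pages so far"}
escalation = {"message": "Login wall hit, need credentials", "wait_for_response": True}
final = {"message": "Done", "data": {"pages": 120}, "mark_complete": True}
print(all(validate_report(a) for a in (fire_and_forget, escalation, final)))
```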
def handle_set_output(
tool_input: dict[str, Any],
output_keys: list[str] | None,
) -> ToolResult:
"""Handle set_output tool call. Returns ToolResult (sync)."""
import logging
import re
logger = logging.getLogger(__name__)
key = tool_input.get("key", "")
value = tool_input.get("value", "")
valid_keys = output_keys or []
# Recover from truncated JSON (max_tokens hit mid-argument).
# The _raw key is set by litellm when json.loads fails.
if not key and "_raw" in tool_input:
raw = tool_input["_raw"]
key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw)
if key_match:
key = key_match.group(1)
val_match = re.search(r'"value"\s*:\s*"', raw)
if val_match:
start = val_match.end()
value = raw[start:].rstrip()
for suffix in ('"}\n', '"}', '"'):
if value.endswith(suffix):
value = value[: -len(suffix)]
break
if key:
logger.warning(
"Recovered set_output args from truncated JSON: key=%s, value_len=%d",
key,
len(value),
)
# Re-inject so the caller sees proper key/value
tool_input["key"] = key
tool_input["value"] = value
if key not in valid_keys:
return ToolResult(
tool_use_id="",
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
is_error=True,
)
return ToolResult(
tool_use_id="",
content=f"Output '{key}' set successfully.",
is_error=False,
)
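The truncated-JSON recovery path above can be exercised in isolation; `recover_truncated_args` is a hypothetical standalone rewrite of the same regex logic:

```python
import re

def recover_truncated_args(raw: str) -> tuple[str, str]:
    """Recover key/value from a set_output payload cut off by max_tokens (sketch)."""
    key, value = "", ""
    key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw)
    if key_match:
        key = key_match.group(1)
    val_match = re.search(r'"value"\s*:\s*"', raw)
    if val_match:
        # take everything after the opening quote, then strip any closing
        # JSON syntax that survived the truncation
        value = raw[val_match.end():].rstrip()
        for suffix in ('"}\n', '"}', '"'):
            if value.endswith(suffix):
                value = value[: -len(suffix)]
                break
    return key, value

# A payload cut off mid-value by a max_tokens limit:
print(recover_truncated_args('{"key": "summary", "value": "First half of the ans'))
```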
@@ -0,0 +1,542 @@
"""Tool result handling: truncation, spillover, JSON preview, and execution.
Manages tool result size limits, file spillover for large results, and
smart JSON previews. Also includes transient error classification and
the context-window-exceeded error detector.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from pathlib import Path
from typing import Any
from framework.llm.provider import ToolResult, ToolUse
from framework.llm.stream_events import ToolCallEvent
logger = logging.getLogger(__name__)
# Pattern for detecting context-window-exceeded errors across LLM providers.
_CONTEXT_TOO_LARGE_RE = re.compile(
r"context.{0,20}(length|window|limit|size)|"
r"too.{0,10}(long|large|many.{0,10}tokens)|"
r"(exceed|exceeds|exceeded).{0,30}(limit|window|context|tokens)|"
r"maximum.{0,20}token|prompt.{0,20}too.{0,10}long",
re.IGNORECASE,
)
def is_context_too_large_error(exc: BaseException) -> bool:
"""Detect whether an exception indicates the LLM input was too large."""
cls = type(exc).__name__
if "ContextWindow" in cls:
return True
return bool(_CONTEXT_TOO_LARGE_RE.search(str(exc)))
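A quick check of the pattern against representative provider messages (paraphrased, not verbatim API output):

```python
import re

# The same pattern _CONTEXT_TOO_LARGE_RE compiles above, reproduced verbatim.
CONTEXT_TOO_LARGE_RE = re.compile(
    r"context.{0,20}(length|window|limit|size)|"
    r"too.{0,10}(long|large|many.{0,10}tokens)|"
    r"(exceed|exceeds|exceeded).{0,30}(limit|window|context|tokens)|"
    r"maximum.{0,20}token|prompt.{0,20}too.{0,10}long",
    re.IGNORECASE,
)

for msg in (
    "This model's maximum context length is 128000 tokens",
    "prompt is too long: 210000 tokens > 200000 maximum",
    "Rate limit reached for requests",
):
    print(msg, "->", bool(CONTEXT_TOO_LARGE_RE.search(msg)))
```

The first two match (context-size errors); the rate-limit message falls through to the transient classifier instead.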
def is_transient_error(exc: BaseException) -> bool:
"""Classify whether an exception is transient (retryable) vs permanent.
Transient: network errors, rate limits, server errors, timeouts.
Permanent: auth errors, bad requests, context window exceeded.
"""
try:
from litellm.exceptions import (
APIConnectionError,
BadGatewayError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
)
transient_types: tuple[type[BaseException], ...] = (
RateLimitError,
APIConnectionError,
InternalServerError,
BadGatewayError,
ServiceUnavailableError,
TimeoutError,
ConnectionError,
OSError,
)
except ImportError:
transient_types = (TimeoutError, ConnectionError, OSError)
if isinstance(exc, transient_types):
return True
# RuntimeError from StreamErrorEvent with "Stream error:" prefix
if isinstance(exc, RuntimeError):
error_str = str(exc).lower()
transient_keywords = [
"rate limit",
"429",
"timeout",
"connection",
"internal server",
"502",
"503",
"504",
"service unavailable",
"bad gateway",
"overloaded",
"failed to parse tool call",
]
return any(kw in error_str for kw in transient_keywords)
return False
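A minimal sketch of how a caller might use this transient/permanent split to drive retries; `call_with_retries`, its parameters, and the toy classifier are illustrative, not framework API:

```python
import time

def call_with_retries(fn, *, is_transient, max_retries=3, backoff_base=2.0):
    """Retry fn() on transient errors with exponential backoff (sketch)."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception as exc:
            # permanent errors and exhausted budgets propagate immediately
            if attempt == max_retries or not is_transient(exc):
                raise
            time.sleep(backoff_base ** attempt * 0.001)  # scaled down for the demo

attempts = []
def flaky():
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("connection reset")
    return "ok"

# Stand-in for is_transient_error when litellm is unavailable:
result = call_with_retries(
    flaky, is_transient=lambda e: isinstance(e, (TimeoutError, ConnectionError, OSError))
)
print(result)
```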
def extract_json_metadata(parsed: Any, *, _depth: int = 0, _max_depth: int = 3) -> str:
"""Return a concise structural summary of parsed JSON.
Reports key names, value types, and crucially array lengths so
the LLM knows how much data exists beyond the preview.
Returns an empty string for simple scalars.
"""
if _depth >= _max_depth:
if isinstance(parsed, dict):
return f"dict with {len(parsed)} keys"
if isinstance(parsed, list):
return f"list of {len(parsed)} items"
return type(parsed).__name__
if isinstance(parsed, dict):
if not parsed:
return "empty dict"
lines: list[str] = []
indent = " " * (_depth + 1)
for key, value in list(parsed.items())[:20]:
if isinstance(value, list):
line = f'{indent}"{key}": list of {len(value)} items'
if value:
first = value[0]
if isinstance(first, dict):
sample_keys = list(first.keys())[:10]
line += f" (each item: dict with keys {sample_keys})"
elif isinstance(first, list):
line += f" (each item: list of {len(first)} elements)"
lines.append(line)
elif isinstance(value, dict):
child = extract_json_metadata(value, _depth=_depth + 1, _max_depth=_max_depth)
lines.append(f'{indent}"{key}": {child}')
else:
lines.append(f'{indent}"{key}": {type(value).__name__}')
if len(parsed) > 20:
lines.append(f"{indent}... and {len(parsed) - 20} more keys")
return "\n".join(lines)
if isinstance(parsed, list):
if not parsed:
return "empty list"
desc = f"list of {len(parsed)} items"
first = parsed[0]
if isinstance(first, dict):
sample_keys = list(first.keys())[:10]
desc += f" (each item: dict with keys {sample_keys})"
elif isinstance(first, list):
desc += f" (each item: list of {len(first)} elements)"
return desc
return ""
def build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None:
"""Build a smart preview of parsed JSON, truncating large arrays.
Shows first 3 + last 1 items of large arrays with explicit count
markers so the LLM cannot mistake the preview for the full dataset.
Returns ``None`` if no truncation was needed (no large arrays).
"""
_LARGE_ARRAY_THRESHOLD = 10
def _truncate_arrays(obj: Any) -> tuple[Any, bool]:
"""Return (truncated_copy, was_truncated)."""
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
n = len(obj)
head = obj[:3]
tail = obj[-1:]
marker = f"... ({n - 4} more items omitted, {n} total) ..."
return head + [marker] + tail, True
if isinstance(obj, dict):
changed = False
out: dict[str, Any] = {}
for k, v in obj.items():
new_v, did = _truncate_arrays(v)
out[k] = new_v
changed = changed or did
return (out, True) if changed else (obj, False)
return obj, False
preview_obj, was_truncated = _truncate_arrays(parsed)
if not was_truncated:
return None # No large arrays — caller should use raw slicing
try:
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
except (TypeError, ValueError):
return None
if len(result) > max_chars:
# Even 3+1 items too big — try just 1 item
def _minimal_arrays(obj: Any) -> Any:
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
n = len(obj)
return obj[:1] + [f"... ({n - 1} more items omitted, {n} total) ..."]
if isinstance(obj, dict):
return {k: _minimal_arrays(v) for k, v in obj.items()}
return obj
preview_obj = _minimal_arrays(parsed)
try:
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
except (TypeError, ValueError):
return None
if len(result) > max_chars:
result = result[:max_chars] + "…"
return result
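The head-plus-tail truncation can be sketched standalone; this simplified `truncate_arrays` mirrors the nested `_truncate_arrays` above:

```python
LARGE_ARRAY_THRESHOLD = 10

def truncate_arrays(obj):
    """Return (copy, was_truncated) with large lists cut to 3 head + 1 tail items."""
    if isinstance(obj, list) and len(obj) > LARGE_ARRAY_THRESHOLD:
        n = len(obj)
        # explicit count marker so the preview cannot pass for the full data
        marker = f"... ({n - 4} more items omitted, {n} total) ..."
        return obj[:3] + [marker] + obj[-1:], True
    if isinstance(obj, dict):
        out, changed = {}, False
        for k, v in obj.items():
            out[k], did = truncate_arrays(v)
            changed = changed or did
        return (out, True) if changed else (obj, False)
    return obj, False

data = {"rows": list(range(50)), "meta": {"count": 50}}
preview, truncated = truncate_arrays(data)
print(preview["rows"])  # [0, 1, 2, '... (46 more items omitted, 50 total) ...', 49]
```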
def truncate_tool_result(
result: ToolResult,
tool_name: str,
*,
max_tool_result_chars: int,
spillover_dir: str | None,
next_spill_filename_fn: Any, # Callable[[str], str]
) -> ToolResult:
"""Persist tool result to file and optionally truncate for context.
When *spillover_dir* is configured, EVERY non-error tool result is
saved to a file (short filename like ``web_search_1.txt``). A
``[Saved to '...']`` annotation is appended so the reference
survives pruning and compaction.
- Small results (≤ limit): full content kept + file annotation
- Large results (> limit): preview + file reference
- Errors: pass through unchanged
- load_data results: truncate with pagination hint (no re-spill)
"""
limit = max_tool_result_chars
# Errors always pass through unchanged
if result.is_error:
return result
# load_data reads FROM spilled files — never re-spill (circular).
# Just truncate with a pagination hint if the result is too large.
if tool_name == "load_data":
if limit <= 0 or len(result.content) <= limit:
return result # Small load_data result — pass through as-is
# Large load_data result — truncate with smart preview
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
metadata_str = ""
smart_preview: str | None = None
try:
parsed_ld = json.loads(result.content)
metadata_str = extract_json_metadata(parsed_ld)
smart_preview = build_json_preview(parsed_ld, max_chars=PREVIEW_CAP)
except (json.JSONDecodeError, TypeError, ValueError):
pass
if smart_preview is not None:
preview_block = smart_preview
else:
preview_block = result.content[:PREVIEW_CAP] + "…"
header = (
f"[{tool_name} result: {len(result.content):,} chars — "
f"too large for context. Use offset_bytes/limit_bytes "
f"parameters to read smaller chunks.]"
)
if metadata_str:
header += f"\n\nData structure:\n{metadata_str}"
header += (
"\n\nWARNING: This is an INCOMPLETE preview. Do NOT draw conclusions or counts from it."
)
truncated = f"{header}\n\nPreview (small sample only):\n{preview_block}"
logger.info(
"%s result truncated: %d → %d chars (use offset/limit to paginate)",
tool_name,
len(result.content),
len(truncated),
)
return ToolResult(
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
is_skill_content=result.is_skill_content,
)
spill_dir = spillover_dir
if spill_dir:
spill_path = Path(spill_dir)
spill_path.mkdir(parents=True, exist_ok=True)
filename = next_spill_filename_fn(tool_name)
# Pretty-print JSON content so load_data's line-based
# pagination works correctly.
write_content = result.content
parsed_json: Any = None # track for metadata extraction
try:
parsed_json = json.loads(result.content)
write_content = json.dumps(parsed_json, indent=2, ensure_ascii=False)
except (json.JSONDecodeError, TypeError, ValueError):
pass # Not JSON — write as-is
(spill_path / filename).write_text(write_content, encoding="utf-8")
if limit > 0 and len(result.content) > limit:
# Large result: build a small, metadata-rich preview so the
# LLM cannot mistake it for the complete dataset.
PREVIEW_CAP = 5000
# Extract structural metadata (array lengths, key names)
metadata_str = ""
smart_preview: str | None = None
if parsed_json is not None:
metadata_str = extract_json_metadata(parsed_json)
smart_preview = build_json_preview(parsed_json, max_chars=PREVIEW_CAP)
if smart_preview is not None:
preview_block = smart_preview
else:
preview_block = result.content[:PREVIEW_CAP] + "…"
# Assemble header with structural info + warning
header = (
f"[Result from {tool_name}: {len(result.content):,} chars — "
f"too large for context, saved to '{filename}'.]\n"
)
if metadata_str:
header += f"\nData structure:\n{metadata_str}"
header += (
f"\n\nWARNING: The preview below is INCOMPLETE. "
f"Do NOT draw conclusions or counts from it. "
f"Use load_data(filename='{filename}') to read the "
f"full data before analysis."
)
content = f"{header}\n\nPreview (small sample only):\n{preview_block}"
logger.info(
"Tool result spilled to file: %s (%d chars → %s)",
tool_name,
len(result.content),
filename,
)
else:
# Small result: keep full content + annotation
content = f"{result.content}\n\n[Saved to '{filename}']"
logger.info(
"Tool result saved to file: %s (%d chars → %s)",
tool_name,
len(result.content),
filename,
)
return ToolResult(
tool_use_id=result.tool_use_id,
content=content,
is_error=False,
image_content=result.image_content,
is_skill_content=result.is_skill_content,
)
# No spillover_dir — truncate in-place if needed
if limit > 0 and len(result.content) > limit:
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
metadata_str = ""
smart_preview: str | None = None
try:
parsed_inline = json.loads(result.content)
metadata_str = extract_json_metadata(parsed_inline)
smart_preview = build_json_preview(parsed_inline, max_chars=PREVIEW_CAP)
except (json.JSONDecodeError, TypeError, ValueError):
pass
if smart_preview is not None:
preview_block = smart_preview
else:
preview_block = result.content[:PREVIEW_CAP] + "…"
header = (
f"[Result from {tool_name}: {len(result.content):,} chars — "
f"truncated to fit context budget.]"
)
if metadata_str:
header += f"\n\nData structure:\n{metadata_str}"
header += (
"\n\nWARNING: This is an INCOMPLETE preview. "
"Do NOT draw conclusions or counts from the preview alone."
)
truncated = f"{header}\n\n{preview_block}"
logger.info(
"Tool result truncated in-place: %s (%d → %d chars)",
tool_name,
len(result.content),
len(truncated),
)
return ToolResult(
tool_use_id=result.tool_use_id,
content=truncated,
is_error=False,
image_content=result.image_content,
is_skill_content=result.is_skill_content,
)
return result
async def execute_tool(
tool_executor: Any, # Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None
tc: ToolCallEvent,
timeout: float,
skill_dirs: list[str] | None = None,
) -> ToolResult:
"""Execute a tool call, handling both sync and async executors.
Applies ``tool_call_timeout_seconds`` to prevent hung MCP servers
from blocking the event loop indefinitely. The initial executor
call is offloaded to a thread pool so that sync executors don't
freeze the event loop.
"""
if tool_executor is None:
return ToolResult(
tool_use_id=tc.tool_use_id,
content=f"No tool executor configured for '{tc.tool_name}'",
is_error=True,
)
skill_dirs = skill_dirs or []
skill_read_tools = {"view_file", "load_data", "read_file"}
if tc.tool_name in skill_read_tools and skill_dirs:
raw_path = tc.tool_input.get("path", "")
if raw_path:
resolved = Path(raw_path).resolve(strict=False)
resolved_roots = [Path(skill_dir).resolve(strict=False) for skill_dir in skill_dirs]
if any(resolved.is_relative_to(root) for root in resolved_roots):
try:
content = resolved.read_text(encoding="utf-8")
except Exception as exc:
return ToolResult(
tool_use_id=tc.tool_use_id,
content=f"Could not read skill resource '{raw_path}': {exc}",
is_error=True,
)
return ToolResult(
tool_use_id=tc.tool_use_id,
content=content,
is_skill_content=resolved.name == "SKILL.md",
)
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
async def _run() -> ToolResult:
# Offload the executor call to a thread. Sync MCP executors
# block on future.result() — running in a thread keeps the
# event loop free so asyncio.wait_for can fire the timeout.
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, tool_executor, tool_use)
# Async executors return a coroutine — await it on the loop
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
result = await result
return result
try:
if timeout > 0:
result = await asyncio.wait_for(_run(), timeout=timeout)
else:
result = await _run()
except TimeoutError:
logger.warning("Tool '%s' timed out after %.0fs", tc.tool_name, timeout)
return ToolResult(
tool_use_id=tc.tool_use_id,
content=(
f"Tool '{tc.tool_name}' timed out after {timeout:.0f}s. "
"The operation took too long and was cancelled. "
"Try a simpler request or a different approach."
),
is_error=True,
)
return result
def record_learning(key: str, value: Any, spillover_dir: str | None) -> None:
"""Append a set_output value to adapt.md as a learning entry.
Called at set_output time, the moment knowledge is produced, so that
adapt.md accumulates the agent's outputs across the session. Since
adapt.md is injected into the system prompt, these persist through
any compaction.
"""
if not spillover_dir:
return
try:
adapt_path = Path(spillover_dir) / "adapt.md"
adapt_path.parent.mkdir(parents=True, exist_ok=True)
content = adapt_path.read_text(encoding="utf-8") if adapt_path.exists() else ""
if "## Outputs" not in content:
content += "\n\n## Outputs\n"
# Truncate long values for memory (full value is in shared memory)
v_str = str(value)
if len(v_str) > 500:
v_str = v_str[:500] + "…"
entry = f"- {key}: {v_str}\n"
# Replace existing entry for same key (update, not duplicate)
lines = content.splitlines(keepends=True)
replaced = False
for i, line in enumerate(lines):
if line.startswith(f"- {key}:"):
lines[i] = entry
replaced = True
break
if replaced:
content = "".join(lines)
else:
content += entry
adapt_path.write_text(content, encoding="utf-8")
except Exception as e:
logger.warning("Failed to record learning for key=%s: %s", key, e)
def next_spill_filename(tool_name: str, counter: int) -> str:
"""Return a short, monotonic filename for a tool result spill."""
# Shorten common tool name prefixes to save tokens
short = tool_name.removeprefix("tool_").removeprefix("mcp_")
return f"{short}_{counter}.txt"
def restore_spill_counter(spillover_dir: str | None) -> int:
"""Scan spillover_dir for existing spill files and return the max counter.
Returns the highest spill number found (or 0 if none).
"""
if not spillover_dir:
return 0
spill_path = Path(spillover_dir)
if not spill_path.is_dir():
return 0
max_n = 0
for f in spill_path.iterdir():
if not f.is_file():
continue
m = re.search(r"_(\d+)\.txt$", f.name)
if m:
max_n = max(max_n, int(m.group(1)))
return max_n
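The two helpers above compose into a resumable spill counter; this sketch reuses their bodies verbatim against a temporary directory:

```python
import re
import tempfile
from pathlib import Path

def next_spill_filename(tool_name: str, counter: int) -> str:
    # same prefix-shortening as the helper above
    short = tool_name.removeprefix("tool_").removeprefix("mcp_")
    return f"{short}_{counter}.txt"

def restore_spill_counter(spillover_dir: str) -> int:
    # scan for existing *_N.txt spill files and return the highest N
    max_n = 0
    for f in Path(spillover_dir).iterdir():
        m = re.search(r"_(\d+)\.txt$", f.name)
        if f.is_file() and m:
            max_n = max(max_n, int(m.group(1)))
    return max_n

with tempfile.TemporaryDirectory() as d:
    for name in ("web_search_1.txt", "web_search_2.txt", "load_data_7.txt"):
        (Path(d) / name).write_text("spilled", encoding="utf-8")
    # after a restart, the next spill continues from the highest number found
    print(next_spill_filename("mcp_web_search", restore_spill_counter(d) + 1))
```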
@@ -0,0 +1,190 @@
"""Shared types and state containers for the event loop package."""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Protocol, runtime_checkable
from framework.graph.conversation import ConversationStore
logger = logging.getLogger(__name__)
@dataclass
class TriggerEvent:
"""A framework-level trigger signal (timer tick or webhook hit)."""
trigger_type: str
source_id: str
payload: dict[str, Any] = field(default_factory=dict)
timestamp: float = field(default_factory=time.time)
@dataclass
class JudgeVerdict:
"""Result of judge evaluation for the event loop."""
action: Literal["ACCEPT", "RETRY", "ESCALATE"]
# None = no evaluation happened (skip_judge, tool-continue); not logged.
# "" = evaluated but no feedback; logged with default text.
# "..." = evaluated with feedback; logged as-is.
feedback: str | None = None
@runtime_checkable
class JudgeProtocol(Protocol):
"""Protocol for event-loop judges."""
async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
@dataclass
class LoopConfig:
"""Configuration for the event loop."""
max_iterations: int = 50
max_tool_calls_per_turn: int = 30
judge_every_n_turns: int = 1
stall_detection_threshold: int = 3
stall_similarity_threshold: float = 0.85
max_context_tokens: int = 32_000
store_prefix: str = ""
# Overflow margin for max_tool_calls_per_turn. Tool calls are only
# discarded when the count exceeds max_tool_calls_per_turn * (1 + margin).
tool_call_overflow_margin: float = 0.5
# Tool result context management.
max_tool_result_chars: int = 30_000
spillover_dir: str | None = None
# set_output value spilling.
max_output_value_chars: int = 2_000
# Stream retry.
max_stream_retries: int = 3
stream_retry_backoff_base: float = 2.0
stream_retry_max_delay: float = 60.0
# Tool doom loop detection.
tool_doom_loop_threshold: int = 3
# Client-facing auto-block grace period.
cf_grace_turns: int = 1
tool_doom_loop_enabled: bool = True
# Per-tool-call timeout.
tool_call_timeout_seconds: float = 60.0
# Subagent delegation timeout.
subagent_timeout_seconds: float = 600.0
# Lifecycle hooks.
hooks: dict[str, list] | None = None
def __post_init__(self) -> None:
if self.hooks is None:
object.__setattr__(self, "hooks", {})
@dataclass
class HookContext:
"""Context passed to every lifecycle hook."""
event: str
trigger: str | None
system_prompt: str
@dataclass
class HookResult:
"""What a hook may return to modify node state."""
system_prompt: str | None = None
inject: str | None = None
@dataclass
class OutputAccumulator:
"""Accumulates output key-value pairs with optional write-through persistence."""
values: dict[str, Any] = field(default_factory=dict)
store: ConversationStore | None = None
spillover_dir: str | None = None
max_value_chars: int = 0
async def set(self, key: str, value: Any) -> None:
"""Set a key-value pair, auto-spilling large values to files."""
value = self._auto_spill(key, value)
self.values[key] = value
if self.store:
cursor = await self.store.read_cursor() or {}
outputs = cursor.get("outputs", {})
outputs[key] = value
cursor["outputs"] = outputs
await self.store.write_cursor(cursor)
def _auto_spill(self, key: str, value: Any) -> Any:
"""Save large values to a file and return a reference string."""
if self.max_value_chars <= 0 or not self.spillover_dir:
return value
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
if len(val_str) <= self.max_value_chars:
return value
spill_path = Path(self.spillover_dir)
spill_path.mkdir(parents=True, exist_ok=True)
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
filename = f"output_{key}{ext}"
write_content = (
json.dumps(value, indent=2, ensure_ascii=False)
if isinstance(value, (dict, list))
else str(value)
)
(spill_path / filename).write_text(write_content, encoding="utf-8")
file_size = (spill_path / filename).stat().st_size
logger.info(
"set_output value auto-spilled: key=%s, %d chars -> %s (%d bytes)",
key,
len(val_str),
filename,
file_size,
)
return (
f"[Saved to '{filename}' ({file_size:,} bytes). "
f"Use load_data(filename='{filename}') "
f"to access full data.]"
)
def get(self, key: str) -> Any | None:
return self.values.get(key)
def to_dict(self) -> dict[str, Any]:
return dict(self.values)
def has_all_keys(self, required: list[str]) -> bool:
return all(key in self.values and self.values[key] is not None for key in required)
@classmethod
async def restore(cls, store: ConversationStore) -> OutputAccumulator:
cursor = await store.read_cursor()
values = {}
if cursor and "outputs" in cursor:
values = cursor["outputs"]
return cls(values=values, store=store)
__all__ = [
"HookContext",
"HookResult",
"JudgeProtocol",
"JudgeVerdict",
"LoopConfig",
"OutputAccumulator",
"TriggerEvent",
]
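The accumulator's completion check can be sketched without the store or spillover machinery; `MiniAccumulator` is a stripped-down stand-in for illustration only:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class MiniAccumulator:
    """Stripped-down OutputAccumulator: no store, no spillover (sketch)."""
    values: dict[str, Any] = field(default_factory=dict)

    def set(self, key: str, value: Any) -> None:
        self.values[key] = value

    def has_all_keys(self, required: list[str]) -> bool:
        # a key present but set to None does not count as satisfied
        return all(k in self.values and self.values[k] is not None for k in required)

acc = MiniAccumulator()
acc.set("summary", "All tests green")
acc.set("details", None)
print(acc.has_all_keys(["summary"]), acc.has_all_keys(["summary", "details"]))
```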
File diff suppressed because it is too large
@@ -155,6 +155,8 @@ class GraphExecutor:
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
context_warn_ratio: float | None = None,
batch_init_nudge: str | None = None,
):
"""
Initialize the executor.
@@ -183,6 +185,8 @@ class GraphExecutor:
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
"""
self.runtime = runtime
self.llm = llm
@@ -207,6 +211,8 @@ class GraphExecutor:
self.skills_catalog_prompt = skills_catalog_prompt
self.protocols_prompt = protocols_prompt
self.skill_dirs: list[str] = skill_dirs or []
self.context_warn_ratio: float | None = context_warn_ratio
self.batch_init_nudge: str | None = batch_init_nudge
if protocols_prompt:
self.logger.info(
@@ -1906,6 +1912,8 @@ class GraphExecutor:
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
default_skill_warn_ratio=self.context_warn_ratio,
default_skill_batch_nudge=self.batch_init_nudge,
)
VALID_NODE_TYPES = {
@@ -569,6 +569,10 @@ class NodeContext:
skills_catalog_prompt: str = "" # Available skills XML catalog
protocols_prompt: str = "" # Default skill operational protocols
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
# DS-12: batch auto-detection nudge appended to system prompt when input looks like a batch
default_skill_batch_nudge: str | None = None
# DS-13: token usage ratio at which to inject a context preservation warning
default_skill_warn_ratio: float | None = None
# Per-iteration metadata provider — when set, EventLoopNode merges
# the returned dict into node_loop_iteration event data. Used by
@@ -159,6 +159,26 @@ if litellm is not None:
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
litellm.drop_params = True
def _is_ollama_model(model: str) -> bool:
"""Return True for any Ollama model string (ollama/ or ollama_chat/ prefix)."""
return model.startswith("ollama/") or model.startswith("ollama_chat/")
def _ensure_ollama_chat_prefix(model: str) -> str:
"""Normalise Ollama model strings to use the ollama_chat/ prefix.
LiteLLM requires the ``ollama_chat/`` prefix (not ``ollama/``) to enable
native function-calling support. With ``ollama/``, LiteLLM falls back to
JSON-mode tool calls, which the framework cannot parse as real tool calls.
See: https://docs.litellm.ai/docs/providers/ollama#example-usage---tool-calling
"""
if model.startswith("ollama/"):
return "ollama_chat/" + model[len("ollama/") :]
return model
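The normalisation is small enough to verify standalone (both helpers reimplemented here for illustration):

```python
def is_ollama_model(model: str) -> bool:
    """True for any Ollama model string, regardless of prefix variant."""
    return model.startswith(("ollama/", "ollama_chat/"))

def ensure_ollama_chat_prefix(model: str) -> str:
    """Rewrite ollama/ to ollama_chat/ so LiteLLM uses native tool calling."""
    if model.startswith("ollama/"):
        return "ollama_chat/" + model[len("ollama/"):]
    return model

for m in ("ollama/llama3.1", "ollama_chat/qwen2.5", "anthropic/claude-sonnet-4"):
    print(m, "->", ensure_ollama_chat_prefix(m) if is_ollama_model(m) else m)
```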
RATE_LIMIT_MAX_RETRIES = 10
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
@@ -499,7 +519,9 @@ class LiteLLMProvider(LLMProvider):
# Translate kimi/ prefix to anthropic/ so litellm uses the Anthropic
# Messages API handler and routes to that endpoint — no special headers needed.
_original_model = model
if model.lower().startswith("kimi/"):
if _is_ollama_model(model):
model = _ensure_ollama_chat_prefix(model)
elif model.lower().startswith("kimi/"):
model = "anthropic/" + model[len("kimi/") :]
# Normalise api_base: litellm's Anthropic handler appends /v1/messages,
# so the base must be https://api.kimi.com/coding (no /v1 suffix).
@@ -722,6 +744,10 @@ class LiteLLMProvider(LLMProvider):
# Add tools if provided
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires an explicit tool_choice=auto for function calling;
# setdefault lets a caller-supplied value take precedence.
kwargs.setdefault("tool_choice", "auto")
# Add response_format for structured output
# LiteLLM passes this through to the underlying provider
@@ -919,6 +945,10 @@ class LiteLLMProvider(LLMProvider):
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires an explicit tool_choice=auto for function calling;
# setdefault lets a caller-supplied value take precedence.
kwargs.setdefault("tool_choice", "auto")
if response_format:
kwargs["response_format"] = response_format
@@ -1620,6 +1650,10 @@ class LiteLLMProvider(LLMProvider):
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires an explicit tool_choice=auto for function calling;
# setdefault lets a caller-supplied value take precedence.
kwargs.setdefault("tool_choice", "auto")
if response_format:
kwargs["response_format"] = response_format
# The Codex ChatGPT backend (Responses API) rejects several params.
@@ -14,6 +14,8 @@ from typing import Any, Literal
import httpx
from framework.runner.mcp_errors import MCPToolNotFoundError
logger = logging.getLogger(__name__)
@@ -456,7 +458,10 @@ class MCPClient:
self.connect()
if tool_name not in self._tools:
raise ValueError(f"Unknown tool: {tool_name}")
raise MCPToolNotFoundError(
server=self.config.name,
tool_name=tool_name,
)
if self.config.transport == "stdio":
with self._stdio_call_lock:
@@ -507,7 +512,10 @@ class MCPClient:
content_item = result.content[0]
if hasattr(content_item, "text"):
error_text = content_item.text
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
raise RuntimeError(
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
f"Tool '{tool_name}' failed: {error_text}"
)
# Extract content — preserve image blocks alongside text
if result.content:
@@ -558,11 +566,17 @@ class MCPClient:
data = response.json()
if "error" in data:
raise RuntimeError(f"Tool execution error: {data['error']}")
raise RuntimeError(
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
f"Tool '{tool_name}' failed: {data['error']}"
)
return data.get("result", {}).get("content", [])
except Exception as e:
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
raise RuntimeError(
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
f"Failed to call tool via HTTP: Tool '{tool_name}' failed: {e}"
) from e
def _reconnect(self) -> None:
"""Reconnect to the configured MCP server."""
@@ -0,0 +1,99 @@
"""Structured error codes and exceptions for MCP server operations."""
from enum import Enum
class MCPErrorCode(Enum):
"""Standardized error codes for MCP operations."""
MCP_INSTALL_FAILED = "MCP_INSTALL_FAILED"
MCP_AUTH_MISSING = "MCP_AUTH_MISSING"
MCP_CONNECT_TIMEOUT = "MCP_CONNECT_TIMEOUT"
MCP_TOOL_NOT_FOUND = "MCP_TOOL_NOT_FOUND"
MCP_PROTOCOL_MISMATCH = "MCP_PROTOCOL_MISMATCH"
MCP_VERSION_CONFLICT = "MCP_VERSION_CONFLICT"
MCP_HEALTH_FAILED = "MCP_HEALTH_FAILED"
class MCPError(ValueError):
"""Base exception for all structured MCP errors."""
def __init__(self, code: MCPErrorCode, what: str, why: str, fix: str):
self.code = code
self.what = what
self.why = why
self.fix = fix
self.message = (
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
)
super().__init__(self.message)
class MCPToolNotFoundError(MCPError):
def __init__(self, server: str, tool_name: str):
super().__init__(
code=MCPErrorCode.MCP_TOOL_NOT_FOUND,
what=f"Tool '{tool_name}' not found on server '{server}'",
why=f"The server '{server}' does not expose a tool named '{tool_name}'.",
fix=f"Run 'hive mcp inspect {server}' to view available tools.",
)
class MCPConnectTimeoutError(MCPError):
def __init__(self, server: str, transport: str, timeout_sec: int):
super().__init__(
code=MCPErrorCode.MCP_CONNECT_TIMEOUT,
what=f"Connection timed out while starting server '{server}'",
why=f"The {transport} transport did not respond within {timeout_sec} seconds.",
fix=f"Check if the server is running. Run 'hive mcp doctor {server}' for diagnostics.",
)
class MCPAuthError(MCPError):
def __init__(self, server: str, env_var: str):
super().__init__(
code=MCPErrorCode.MCP_AUTH_MISSING,
what=f"Authentication failed for server '{server}'",
why=f"The required environment variable '{env_var}' is missing or empty.",
fix=f"Run: hive mcp config {server} --set {env_var}=<your-token>",
)
class MCPInstallError(MCPError):
def __init__(self, server: str, why: str, fix: str):
super().__init__(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Could not install MCP server '{server}'",
why=why,
fix=fix,
)
class MCPProtocolMismatchError(MCPError):
def __init__(self, server: str, detail: str):
super().__init__(
code=MCPErrorCode.MCP_PROTOCOL_MISMATCH,
what=f"Protocol mismatch with server '{server}'",
why=detail,
fix=f"Check the MCP SDK version required by '{server}' matches your installation.",
)
class MCPVersionConflictError(MCPError):
def __init__(self, server: str, detail: str):
super().__init__(
code=MCPErrorCode.MCP_VERSION_CONFLICT,
what=f"Version conflict with server '{server}'",
why=detail,
fix="Update or pin the MCP server package to a compatible version.",
)
class MCPHealthCheckError(MCPError):
def __init__(self, server: str, detail: str):
super().__init__(
code=MCPErrorCode.MCP_HEALTH_FAILED,
what=f"Health check failed for server '{server}'",
why=detail,
fix=f"Run 'hive mcp doctor {server}' to diagnose the issue.",
)
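The what/why/fix structure renders like this; a minimal re-sketch of the base class with a single code, not the full hierarchy:

```python
from enum import Enum

class MCPErrorCode(Enum):
    MCP_TOOL_NOT_FOUND = "MCP_TOOL_NOT_FOUND"

class MCPError(ValueError):
    """Structured MCP error: code plus what/why/fix diagnostics (sketch)."""
    def __init__(self, code: MCPErrorCode, what: str, why: str, fix: str):
        self.code, self.what, self.why, self.fix = code, what, why, fix
        super().__init__(
            f"[{code.value}]\nWhat failed: {what}\nWhy: {why}\nFix: {fix}"
        )

err = MCPError(
    MCPErrorCode.MCP_TOOL_NOT_FOUND,
    what="Tool 'fetch' not found on server 'web'",
    why="The server 'web' does not expose a tool named 'fetch'.",
    fix="Run 'hive mcp inspect web' to view available tools.",
)
print(str(err).splitlines()[0])  # -> [MCP_TOOL_NOT_FOUND]
```

Subclassing `ValueError` keeps existing `except ValueError` call sites working while the registry migrates to the structured hierarchy.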
@@ -16,6 +16,11 @@ import httpx
from framework.runner.mcp_client import MCPClient, MCPServerConfig
from framework.runner.mcp_connection_manager import MCPConnectionManager
from framework.runner.mcp_errors import (
MCPError,
MCPErrorCode,
MCPInstallError,
)
logger = logging.getLogger(__name__)
@@ -141,7 +146,12 @@ class MCPRegistry:
"""
data = self._read_installed()
if name in data["servers"]:
raise ValueError(f"Server '{name}' already exists. Use remove first.")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Server '{name}' already exists",
why="A server with this name is already registered locally.",
fix=f"Run: hive mcp remove {name} — then add it again.",
)
if manifest is not None:
# Inline manifest provided directly
@@ -153,7 +163,12 @@ class MCPRegistry:
else:
# Build manifest from individual params
if not transport:
raise ValueError("transport is required when manifest is not provided")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}'",
why="transport is required when manifest is not provided.",
fix="Pass --transport stdio|http|unix|sse when using hive mcp add.",
)
manifest = {
"name": name,
"description": description,
@@ -162,11 +177,21 @@ class MCPRegistry:
match transport:
case "http":
if not url:
raise ValueError("url is required for http transport")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}' with http transport",
why="url is required for http transport.",
fix="Pass --url https://your-server to hive mcp add.",
)
manifest["http"] = {"url": url, "headers": headers or {}}
case "stdio":
if not command:
raise ValueError("command is required for stdio transport")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}' with stdio transport",
why="command is required for stdio transport.",
fix="Pass --command <executable> to hive mcp add.",
)
manifest["stdio"] = {
"command": command,
"args": args or [],
@@ -175,15 +200,30 @@ class MCPRegistry:
}
case "unix":
if not socket_path:
raise ValueError("socket_path is required for unix transport")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}' with unix transport",
why="socket_path is required for unix transport.",
fix="Pass --socket-path /path/to/socket to hive mcp add.",
)
manifest["unix"] = {"socket_path": socket_path}
manifest["http"] = {"url": url or "http://localhost"}
case "sse":
if not url:
raise ValueError("url is required for sse transport")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}' with sse transport",
why="url is required for sse transport.",
fix="Pass --url https://your-server to hive mcp add.",
)
manifest["sse"] = {"url": url}
case _:
raise ValueError(f"Unsupported transport: {transport}")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot register server '{name}'",
why=f"Unsupported transport: '{transport}'.",
fix="Use one of: stdio, http, unix, sse.",
)
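The four `case` arms each enforce one required flag. The same checks can be written table-driven; a sketch of the validation logic only, not how `add_local` is implemented (flag names mirror the CLI options registered below):

```python
# transport -> (required kwarg, CLI flag that supplies it)
REQUIRED_FLAG = {
    "http": ("url", "--url"),
    "stdio": ("command", "--command"),
    "unix": ("socket_path", "--socket-path"),
    "sse": ("url", "--url"),
}


def check_transport(transport: str, **kwargs) -> None:
    """Raise ValueError when a transport is unknown or its required field is unset."""
    if transport not in REQUIRED_FLAG:
        raise ValueError(
            f"Unsupported transport: '{transport}'. Use one of: {sorted(REQUIRED_FLAG)}"
        )
    field, flag = REQUIRED_FLAG[transport]
    if not kwargs.get(field):
        raise ValueError(f"{field} is required for {transport} transport. Pass {flag}.")


check_transport("http", url="http://localhost:8080")  # passes silently
```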
entry = self._make_entry(
source="local",
@@ -203,34 +243,48 @@ class MCPRegistry:
"""Install a server from the cached remote registry index."""
data = self._read_installed()
if name in data["servers"]:
raise ValueError(f"Server '{name}' already exists. Remove it first or use update.")
raise MCPInstallError(
server=name,
why=f"Server '{name}' already exists in the registry.",
fix=f"Run: hive mcp remove {name} — then install again.",
)
index = self._read_cached_index()
manifest = index.get("servers", {}).get(name)
if manifest is None:
raise ValueError(
f"Server '{name}' not found in registry index. "
"Run 'hive mcp update' to refresh the index."
raise MCPInstallError(
server=name,
why=f"Server '{name}' not found in registry index.",
fix="Run: hive mcp update — then try again.",
)
# Validate version if specified
if version is not None:
index_version = manifest.get("version")
if index_version is None:
raise ValueError(f"Cannot pin version for '{name}': manifest has no version field.")
raise MCPError(
code=MCPErrorCode.MCP_VERSION_CONFLICT,
what=f"Cannot pin version for '{name}'",
why="The registry manifest has no version field.",
fix="Run: hive mcp update — then omit --version to use latest.",
)
if index_version != version:
raise ValueError(
f"Version mismatch for '{name}': requested {version}, "
f"index has {index_version}. "
"Run 'hive mcp update' to refresh the index."
raise MCPError(
code=MCPErrorCode.MCP_VERSION_CONFLICT,
what=f"Version mismatch for '{name}'",
why=f"Requested {version} but index has {index_version}.",
fix="Run: hive mcp update — or omit --version to use latest.",
)
transport_config = manifest.get("transport", {})
supported = transport_config.get("supported", [])
if transport is not None:
if supported and transport not in supported:
raise ValueError(
f"Transport '{transport}' not supported by '{name}'. Supported: {supported}"
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Transport '{transport}' not supported by '{name}'",
why=f"Server supports: {supported}.",
fix=f"Use one of the supported transports: {supported}.",
)
resolved_transport = transport
else:
@@ -261,7 +315,12 @@ class MCPRegistry:
"""Remove a server from the registry."""
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot remove server '{name}'",
why="Server is not installed.",
fix="Run: hive mcp list — to see installed servers.",
)
del data["servers"][name]
self._write_installed(data)
logger.info("Removed MCP server '%s'", name)
@@ -277,7 +336,12 @@ class MCPRegistry:
def _set_enabled(self, name: str, *, enabled: bool) -> None:
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot {'enable' if enabled else 'disable'} server '{name}'",
why="Server is not installed.",
fix="Run: hive mcp list — to see installed servers.",
)
data["servers"][name]["enabled"] = enabled
self._write_installed(data)
logger.info("%s MCP server '%s'", "Enabled" if enabled else "Disabled", name)
@@ -314,9 +378,19 @@ class MCPRegistry:
"""Set an env or header override for a server."""
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Cannot set override for server '{name}'",
why="Server is not installed.",
fix="Run: hive mcp list — to see installed servers.",
)
if override_type not in ("env", "headers"):
raise ValueError(f"Invalid override type: {override_type}")
raise MCPError(
code=MCPErrorCode.MCP_INSTALL_FAILED,
what=f"Invalid override type '{override_type}' for server '{name}'",
why="Override type must be 'env' or 'headers'.",
fix="Use --type env or --type headers.",
)
data["servers"][name]["overrides"][override_type][key] = value
self._write_installed(data)
logger.info("Set %s override %s for MCP server '%s'", override_type, key, name)
@@ -401,14 +475,16 @@ class MCPRegistry:
# ── load_agent_selection ────────────────────────────────────────
def load_agent_selection(self, agent_path: Path) -> list[dict[str, Any]]:
def load_agent_selection(self, agent_path: Path) -> tuple[list[dict[str, Any]], int | None]:
"""Load mcp_registry.json from an agent directory and resolve servers.
Returns list of plain dicts compatible with ToolRegistry.register_mcp_server().
Returns:
(server_config_dicts, max_tools) for :meth:`ToolRegistry.load_registry_servers`.
``max_tools`` is ``None`` when omitted or invalid in JSON.
"""
registry_json_path = agent_path / "mcp_registry.json"
if not registry_json_path.exists():
return []
return [], None
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
@@ -437,15 +513,16 @@ class MCPRegistry:
continue
validated[field] = value
max_tools = validated.get("max_tools")
configs = self.resolve_for_agent(
include=validated.get("include"),
tags=validated.get("tags"),
exclude=validated.get("exclude"),
profile=validated.get("profile"),
max_tools=validated.get("max_tools"),
max_tools=max_tools,
versions=validated.get("versions"),
)
return [self._server_config_to_dict(c) for c in configs]
return [self._server_config_to_dict(c) for c in configs], max_tools
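Because the return type changed from a list to a `(configs, max_tools)` tuple, every caller must now unpack it; a sketch with a stub standing in for the real `MCPRegistry` method:

```python
def load_agent_selection(agent_path: str):
    # Stub for MCPRegistry.load_agent_selection after this change:
    # it returns (server_config_dicts, max_tools); max_tools may be None.
    return [{"name": "jira"}, {"name": "github"}], 40


configs, max_tools = load_agent_selection("exports/my-agent")
names = {c["name"] for c in configs}  # {'jira', 'github'}
```

Iterating the raw return value without unpacking yields the list and the int instead of config dicts, which is an easy mistake for existing call sites that expected the old single list.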
# ── resolve_for_agent ───────────────────────────────────────────
@@ -552,12 +629,14 @@ class MCPRegistry:
)
continue
# Check tool count cap before adding (FR-56)
# Check tool count cap before adding (FR-56), using manifest tool list when present.
# When ``tools`` is empty (e.g. ``add_local``), counts are unknown here—callers should
# pass the same ``max_tools`` to ToolRegistry.load_registry_servers to cap registration.
manifest_tools = manifest.get("tools", [])
server_tool_count = len(manifest_tools)
if max_tools is not None and server_tool_count == 0:
logger.debug(
"Server '%s' has no declared tools in manifest, skipping max_tools check",
"Server '%s' has no tools list in manifest; max_tools enforced at registration",
name,
)
elif max_tools is not None and total_tools + server_tool_count > max_tools:
@@ -693,7 +772,12 @@ class MCPRegistry:
data = self._read_installed()
if name not in data["servers"]:
raise ValueError(f"Server '{name}' is not installed.")
raise MCPError(
code=MCPErrorCode.MCP_HEALTH_FAILED,
what=f"Cannot health-check server '{name}'",
why="Server is not installed.",
fix="Run: hive mcp list — to see installed servers.",
)
entry = data["servers"][name]
manifest = self._get_effective_manifest(name, entry)
@@ -728,7 +812,12 @@ class MCPRegistry:
if manager.has_connection(name):
is_healthy = manager.health_check(name)
if not is_healthy:
raise RuntimeError("Shared MCP connection health check failed")
raise MCPError(
code=MCPErrorCode.MCP_HEALTH_FAILED,
what=f"Health check failed for server '{name}'",
why="Shared MCP connection reported unhealthy.",
fix=f"Run: hive mcp doctor {name} — for diagnostics.",
)
pooled_client = manager.acquire(config)
try:
tools = pooled_client.list_tools()
@@ -0,0 +1,906 @@
"""CLI commands for MCP server registry management.
Commands:
hive mcp install <name> Install a server from the registry
hive mcp add Register a local/running MCP server
hive mcp remove <name> Remove an installed server
hive mcp enable <name> Enable a server
hive mcp disable <name> Disable a server
hive mcp list List installed servers
hive mcp info <name> Show server details
hive mcp config <name> Set env/header overrides
hive mcp search <query> Search the registry index
hive mcp health [name] Check server health
hive mcp update Refresh index and update installed servers
hive mcp update <name> Update a single installed server
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
from typing import Any
# ── Shared helpers ──────────────────────────────────────────────────
def _get_registry(base_path: Path | None = None):
"""Initialize and return an MCPRegistry instance."""
from framework.runner.mcp_registry import MCPRegistry
registry = MCPRegistry(base_path=base_path)
registry.initialize()
return registry
def _ensure_index_available(registry) -> bool:
"""Ensure the registry index is cached locally.
If no index exists or the cache is stale, fetches a fresh copy.
Returns True if a usable index exists, False otherwise.
Semantics:
- Stale cache + refresh fails -> warn and continue with stale cache (True)
- No cache + refresh fails -> hard fail (False)
"""
import httpx
cache_exists = (registry._cache_dir / "registry_index.json").exists()
if registry.is_index_stale():
print("Updating registry index...", file=sys.stderr)
try:
count = registry.update_index()
print(f"Registry index updated ({count} servers available).", file=sys.stderr)
return True
except (httpx.HTTPError, OSError) as exc:
if cache_exists:
print(
f"Warning: failed to update registry index: {exc}\nUsing cached index.",
file=sys.stderr,
)
return True
print(
f"Error: no registry index available and refresh failed: {exc}\n"
"Check your network connection and try: hive mcp update",
file=sys.stderr,
)
return False
return cache_exists
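The three outcomes in the docstring reduce to a small decision table; a sketch under the assumption that only `cache_exists`, `is_stale`, and `refresh_ok` drive the result:

```python
def index_usable(cache_exists: bool, is_stale: bool, refresh_ok: bool) -> bool:
    # Mirrors _ensure_index_available: a stale cache survives a failed refresh,
    # but no cache at all plus a failed refresh is a hard failure.
    if not is_stale:
        return cache_exists
    return refresh_ok or cache_exists


assert index_usable(cache_exists=True, is_stale=True, refresh_ok=False)       # warn, use stale
assert not index_usable(cache_exists=False, is_stale=True, refresh_ok=False)  # hard fail
assert index_usable(cache_exists=False, is_stale=True, refresh_ok=True)       # fresh fetch
```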
_SECURITY_NOTICE = (
"Registry servers run code on your machine. Only install servers you trust.\n"
"Learn more: https://github.com/aden-hive/hive-mcp-registry"
)
_NOTICE_SENTINEL = ".security_notice_shown"
def _print_security_notice_if_first_use(registry_base: Path) -> None:
"""Print a one-time security notice on first registry install.
Only prints the notice. Call _mark_security_notice_shown() after
a successful install to persist the sentinel.
"""
sentinel = registry_base / _NOTICE_SENTINEL
if sentinel.exists():
return
print(f"\n {_SECURITY_NOTICE}\n", file=sys.stderr)
def _mark_security_notice_shown(registry_base: Path) -> None:
"""Persist the security notice sentinel after a successful install."""
sentinel = registry_base / _NOTICE_SENTINEL
try:
sentinel.touch()
except OSError:
pass
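The two helpers deliberately split the notice into print-then-persist: the sentinel is only written after a successful install, so a failed install shows the warning again on the next attempt. The sentinel pattern in isolation:

```python
import tempfile
from pathlib import Path

SENTINEL = ".security_notice_shown"


def notice_needed(base: Path) -> bool:
    return not (base / SENTINEL).exists()


def mark_shown(base: Path) -> None:
    (base / SENTINEL).touch()  # persist only after a successful install


base = Path(tempfile.mkdtemp())
assert notice_needed(base)      # first run: show the notice
mark_shown(base)
assert not notice_needed(base)  # subsequent runs: stay quiet
```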
def _prompt_for_missing_credentials(
registry,
name: str,
manifest: dict,
) -> None:
"""Prompt for required credentials not already set in env or overrides."""
credentials = manifest.get("credentials", [])
if not credentials:
return
server = registry.get_server(name)
existing_overrides = server.get("overrides", {}).get("env", {}) if server else {}
prompted = False
for cred in credentials:
if not isinstance(cred, dict):
continue
env_var = cred.get("env_var", "")
if not env_var:
continue
required = cred.get("required", False)
if not required:
continue
# Skip if already in environment or overrides
if os.environ.get(env_var) or existing_overrides.get(env_var):
continue
if not prompted:
print(f"\n{name} requires credentials:", file=sys.stderr)
prompted = True
description = cred.get("description", env_var)
help_url = cred.get("help_url", "")
help_hint = f" (get one at {help_url})" if help_url else ""
try:
value = input(f" {description}{help_hint}\n {env_var}: ").strip()
except (EOFError, KeyboardInterrupt):
print("\nSkipped credential prompting.", file=sys.stderr)
return
if value:
registry.set_override(name, env_var, value, override_type="env")
def _parse_key_value_pairs(values: list[str]) -> dict[str, str]:
"""Parse KEY=VAL pairs from CLI args. Raises ValueError on bad format."""
result = {}
for item in values:
if "=" not in item:
raise ValueError(
f"Invalid format: '{item}'. Expected KEY=VALUE.\n"
f"Example: --set JIRA_API_TOKEN=abc123"
)
key, _, value = item.partition("=")
if not key:
raise ValueError(f"Invalid format: '{item}'. Key cannot be empty.")
result[key] = value
return result
def _find_agents_using_server(registry, name: str) -> list[str]:
"""Scan agent directories for mcp_registry.json files that would load a server.
Uses MCPRegistry.load_agent_selection() to resolve actual selection logic
so results stay consistent with runtime behavior.
"""
agent_dirs: list[Path] = []
# parents: [0]=runner, [1]=framework, [2]=core, [3]=hive (project root)
# NOTE: This path arithmetic assumes running from the source tree layout.
# It will not resolve correctly if installed via pip into site-packages.
project_root = Path(__file__).resolve().parents[3]
core_dir = Path(__file__).resolve().parents[2]
candidates = [
project_root / "exports",
core_dir / "exports",
core_dir / "framework" / "agents",
]
for candidate in candidates:
if candidate.is_dir():
for child in candidate.iterdir():
if child.is_dir():
agent_dirs.append(child)
matches = []
for agent_dir in agent_dirs:
registry_json = agent_dir / "mcp_registry.json"
if not registry_json.exists():
continue
try:
configs, _max_tools = registry.load_agent_selection(agent_dir)
resolved_names = {c["name"] for c in configs}
if name in resolved_names:
matches.append(str(agent_dir))
except Exception:
continue
return matches
def _render_installed_table(entries: list[dict]) -> None:
"""Render installed servers as a formatted table."""
if not entries:
print("No servers installed.")
print("Run 'hive mcp install <name>' or 'hive mcp add' to get started.")
return
# Column widths
name_w = max(len(e["name"]) for e in entries)
name_w = max(name_w, 4)
transport_w = max(len(e.get("transport", "")) for e in entries)
transport_w = max(transport_w, 9)
header = (
f" {'NAME':<{name_w}} "
f"{'TRANSPORT':<{transport_w}} "
f"{'ENABLED':<7} "
f"{'HEALTH':<9} "
f"{'TOOLS':<5} "
f"{'TRUST':<10} "
f"{'SOURCE'}"
)
print(header)
print(" " + "" * (len(header) - 2))
for entry in entries:
enabled = "yes" if entry.get("enabled", True) else "no"
health = entry.get("last_health_status") or "unknown"
health_sym = {"healthy": "", "unhealthy": ""}.get(health, "")
source = entry.get("source", "")
manifest = entry.get("manifest", {})
tools_count = str(len(manifest.get("tools", [])))
trust_tier = manifest.get("status", "")
print(
f" {entry['name']:<{name_w}} "
f"{entry.get('transport', ''):<{transport_w}} "
f"{enabled:<7} "
f"{health_sym} {health:<7} "
f"{tools_count:<5} "
f"{trust_tier:<10} "
f"{source}"
)
def _render_available_table(entries: list[dict]) -> None:
"""Render available registry servers as a formatted table."""
if not entries:
print("No servers in registry index.")
print("Run 'hive mcp update' to refresh the index.")
return
name_w = max(len(e["name"]) for e in entries)
name_w = max(name_w, 4)
header = f" {'NAME':<{name_w}} {'VERSION':<9} {'STATUS':<10} DESCRIPTION"
print(header)
print(" " + "" * (len(header) - 2))
for entry in entries:
version = entry.get("version", "")
status = entry.get("status", "community")
desc = entry.get("description", "")
# Truncate long descriptions
if len(desc) > 60:
desc = desc[:57] + "..."
print(f" {entry['name']:<{name_w}} {version:<9} {status:<10} {desc}")
def _mask_overrides(overrides: dict) -> dict:
"""Replace override values with '<set>' markers. Shared by all output paths."""
masked: dict[str, dict[str, str]] = {}
if overrides.get("env"):
masked["env"] = dict.fromkeys(overrides["env"], "<set>")
else:
masked["env"] = {}
if overrides.get("headers"):
masked["headers"] = dict.fromkeys(overrides["headers"], "<set>")
else:
masked["headers"] = {}
return masked
def _emit_json(data: Any) -> None:
"""Print data as formatted JSON."""
print(json.dumps(data, indent=2, default=str))
# ── Command registration ───────────────────────────────────────────
def register_mcp_commands(subparsers) -> None:
"""Register the ``hive mcp`` subcommand group."""
mcp_parser = subparsers.add_parser("mcp", help="Manage MCP servers")
mcp_sub = mcp_parser.add_subparsers(dest="mcp_command", required=True)
# ── install ──
install_p = mcp_sub.add_parser("install", help="Install a server from the registry")
install_p.add_argument("name", help="Server name in the registry")
install_p.add_argument(
"--version", dest="version", default=None, help="Pin to a specific version"
)
install_p.add_argument(
"--transport", default=None, help="Override default transport (stdio, http, unix, sse)"
)
install_p.set_defaults(func=cmd_mcp_install)
# ── add ──
add_p = mcp_sub.add_parser("add", help="Register a local/running MCP server")
add_p.add_argument("--name", required=False, help="Server name")
add_p.add_argument(
"--transport",
choices=["stdio", "http", "unix", "sse"],
default=None,
help="Transport type",
)
add_p.add_argument("--url", default=None, help="Server URL (http, unix, sse)")
add_p.add_argument("--command", default=None, help="Command to run (stdio)")
add_p.add_argument("--args", nargs="*", default=None, help="Command arguments (stdio)")
add_p.add_argument("--socket-path", default=None, help="Unix socket path")
add_p.add_argument("--description", default="", help="Server description")
add_p.add_argument("--from", dest="from_manifest", default=None, help="Path to manifest.json")
add_p.set_defaults(func=cmd_mcp_add)
# ── remove ──
remove_p = mcp_sub.add_parser("remove", help="Remove an installed server")
remove_p.add_argument("name", help="Server name")
remove_p.set_defaults(func=cmd_mcp_remove)
# ── enable ──
enable_p = mcp_sub.add_parser("enable", help="Enable a disabled server")
enable_p.add_argument("name", help="Server name")
enable_p.set_defaults(func=cmd_mcp_enable)
# ── disable ──
disable_p = mcp_sub.add_parser("disable", help="Disable a server without removing it")
disable_p.add_argument("name", help="Server name")
disable_p.set_defaults(func=cmd_mcp_disable)
# ── list ──
list_p = mcp_sub.add_parser("list", help="List servers")
list_p.add_argument(
"--available", action="store_true", help="Show available servers from registry"
)
list_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
list_p.set_defaults(func=cmd_mcp_list)
# ── info ──
info_p = mcp_sub.add_parser("info", help="Show server details")
info_p.add_argument("name", help="Server name")
info_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
info_p.set_defaults(func=cmd_mcp_info)
# ── config ──
config_p = mcp_sub.add_parser("config", help="Set server configuration overrides")
config_p.add_argument("name", help="Server name")
config_p.add_argument(
"--set",
dest="set_env",
nargs="+",
metavar="KEY=VAL",
help="Set environment variable overrides",
)
config_p.add_argument(
"--set-header", dest="set_header", nargs="+", metavar="KEY=VAL", help="Set header overrides"
)
config_p.set_defaults(func=cmd_mcp_config)
# ── search ──
search_p = mcp_sub.add_parser("search", help="Search the registry")
search_p.add_argument("query", help="Search term (name, tag, description, tool name)")
search_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
search_p.set_defaults(func=cmd_mcp_search)
# ── health ──
health_p = mcp_sub.add_parser("health", help="Check server health")
health_p.add_argument("name", nargs="?", default=None, help="Server name (all if omitted)")
health_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
health_p.set_defaults(func=cmd_mcp_health)
# ── update ──
update_p = mcp_sub.add_parser(
"update", help="Update installed servers or refresh the registry index"
)
update_p.add_argument(
"name",
nargs="?",
default=None,
help="Server name to update (omit to update all registry servers)",
)
update_p.set_defaults(func=cmd_mcp_update)
# ── P0 command handlers ────────────────────────────────────────────
def cmd_mcp_install(args) -> int:
"""Install a server from the registry index."""
registry = _get_registry()
_print_security_notice_if_first_use(registry._base)
if not _ensure_index_available(registry):
return 1
try:
entry = registry.install(
args.name,
transport=args.transport,
version=args.version,
)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
_mark_security_notice_shown(registry._base)
version_str = entry.get("manifest_version", "")
transport = entry.get("transport", "")
print(f"✓ Installed {args.name} v{version_str} ({transport})")
# Prompt for credentials defined in the manifest
manifest = entry.get("manifest", {})
_prompt_for_missing_credentials(registry, args.name, manifest)
print("\nNext steps:")
print(f" hive mcp health {args.name} Check that the server is reachable")
print(f" hive mcp info {args.name} View server details")
return 0
def cmd_mcp_add(args) -> int:
"""Register a local/running MCP server."""
registry = _get_registry()
# Handle --from manifest.json
if args.from_manifest:
return _cmd_mcp_add_from_manifest(registry, args.from_manifest)
if not args.name:
print(
"Error: --name is required.\n"
"Usage: hive mcp add --name my-server --transport http --url http://localhost:8080\n"
" or: hive mcp add --from manifest.json",
file=sys.stderr,
)
return 1
if not args.transport:
print(
f"Error: --transport is required.\n"
f"Supported transports: stdio, http, unix, sse\n"
f"Example: hive mcp add --name {args.name} --transport http --url http://localhost:8080",
file=sys.stderr,
)
return 1
try:
entry = registry.add_local(
name=args.name,
transport=args.transport,
url=args.url,
command=args.command,
args=args.args,
socket_path=args.socket_path,
description=args.description,
)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
print(f"✓ Registered {args.name} ({entry['transport']})")
return 0
def _cmd_mcp_add_from_manifest(registry, manifest_path: str) -> int:
"""Register a server from a manifest.json file."""
path = Path(manifest_path)
if not path.exists():
print(
f"Error: manifest file not found: {manifest_path}\nCheck the path and try again.",
file=sys.stderr,
)
return 1
try:
manifest = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError as exc:
print(
f"Error: invalid JSON in {manifest_path}: {exc}\n"
f"Validate with: python -m json.tool {manifest_path}",
file=sys.stderr,
)
return 1
name = manifest.get("name")
if not name:
print(
f"Error: manifest missing 'name' field.\nAdd a 'name' field to {manifest_path}.",
file=sys.stderr,
)
return 1
try:
entry = registry.add_local(name=name, manifest=manifest)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
print(f"✓ Registered {name} from {manifest_path} ({entry['transport']})")
return 0
def cmd_mcp_remove(args) -> int:
"""Remove an installed server."""
registry = _get_registry()
try:
registry.remove(args.name)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
print(f"✓ Removed {args.name}")
return 0
def cmd_mcp_enable(args) -> int:
"""Enable a disabled server."""
registry = _get_registry()
try:
registry.enable(args.name)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
print(f"✓ Enabled {args.name}")
return 0
def cmd_mcp_disable(args) -> int:
"""Disable a server without removing it."""
registry = _get_registry()
try:
registry.disable(args.name)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
print(f"✓ Disabled {args.name}")
return 0
def cmd_mcp_list(args) -> int:
"""List installed or available servers."""
registry = _get_registry()
if args.available:
if not _ensure_index_available(registry):
return 1
entries = registry.list_available()
if args.output_json:
_emit_json(entries)
else:
_render_available_table(entries)
else:
entries = registry.list_installed()
if args.output_json:
safe_entries = []
for entry in entries:
safe = dict(entry)
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
safe_entries.append(safe)
_emit_json(safe_entries)
else:
_render_installed_table(entries)
return 0
def cmd_mcp_info(args) -> int:
"""Show full details for a server."""
registry = _get_registry()
server = registry.get_server(args.name)
if server is None:
print(
f"Error: server '{args.name}' is not installed.\n"
f"Run 'hive mcp list' to see installed servers.\n"
f"Run 'hive mcp install {args.name}' to install from registry.",
file=sys.stderr,
)
return 1
# Enrich with agent usage for both JSON and human output
agents = _find_agents_using_server(registry, args.name)
if agents:
server["used_by_agents"] = agents
if args.output_json:
safe = dict(server)
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
_emit_json(safe)
return 0
manifest = server.get("manifest", {})
overrides = _mask_overrides(server.get("overrides", {}))
tools = manifest.get("tools", [])
status = manifest.get("status", "community")
hive_block = manifest.get("hive", {})
print(f"{server['name']}")
print("=" * 50)
# Core info
print(f" Source: {server.get('source', '')}")
print(f" Transport: {server.get('transport', '')}")
print(f" Version: {server.get('manifest_version', 'unknown')}")
print(f" Trust tier: {status}")
print(f" Enabled: {'yes' if server.get('enabled', True) else 'no'}")
# Description
desc = manifest.get("description", "")
if desc:
print(f" Description: {desc}")
# Health
health = server.get("last_health_status")
if health:
health_sym = {"healthy": "", "unhealthy": ""}.get(health, "")
print(f" Health: {health_sym} {health}")
last_check = server.get("last_health_check_at")
if last_check:
print(f" Last check: {last_check}")
last_error = server.get("last_error")
if last_error:
print(f" Last error: {last_error}")
# Tools
if tools:
print(f"\n Tools ({len(tools)}):")
for tool in tools:
if isinstance(tool, dict):
tool_name = tool.get("name", "")
tool_desc = tool.get("description", "")
print(f"{tool_name}: {tool_desc}" if tool_desc else f"{tool_name}")
else:
print(f"{tool}")
# Overrides
env_overrides = overrides.get("env", {})
header_overrides = overrides.get("headers", {})
if env_overrides or header_overrides:
print("\n Overrides:")
for key in env_overrides:
print(f" env.{key} = <set>")
for key in header_overrides:
print(f" header.{key} = <set>")
# Hive block
if hive_block:
profiles = hive_block.get("profiles", [])
if profiles:
print(f"\n Profiles: {', '.join(profiles)}")
min_ver = hive_block.get("min_version")
if min_ver:
print(f" Min Hive version: {min_ver}")
# Agent usage
if agents:
print("\n Used by agents:")
for agent in agents:
print(f"{agent}")
# Timestamps
print(f"\n Installed: {server.get('installed_at', 'unknown')}")
print(f" Installed by: {server.get('installed_by', 'unknown')}")
return 0
def cmd_mcp_config(args) -> int:
"""Set env or header overrides for a server."""
registry = _get_registry()
if not args.set_env and not args.set_header:
# Show current config
server = registry.get_server(args.name)
if server is None:
print(
f"Error: server '{args.name}' is not installed.\n"
f"Run 'hive mcp list' to see installed servers.",
file=sys.stderr,
)
return 1
masked = _mask_overrides(server.get("overrides", {}))
env_o = masked.get("env", {})
header_o = masked.get("headers", {})
if not env_o and not header_o:
print(f"No overrides set for {args.name}.")
print(f"Set one with: hive mcp config {args.name} --set KEY=VALUE")
else:
print(f"Overrides for {args.name}:")
for key in env_o:
print(f" env.{key} = <set>")
for key in header_o:
print(f" header.{key} = <set>")
return 0
try:
if args.set_env:
pairs = _parse_key_value_pairs(args.set_env)
for key, value in pairs.items():
registry.set_override(args.name, key, value, override_type="env")
print(f"✓ Set {len(pairs)} env override(s) for {args.name}")
if args.set_header:
pairs = _parse_key_value_pairs(args.set_header)
for key, value in pairs.items():
registry.set_override(args.name, key, value, override_type="headers")
print(f"✓ Set {len(pairs)} header override(s) for {args.name}")
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
return 0
# ── P1 command handlers ────────────────────────────────────────────
def cmd_mcp_search(args) -> int:
"""Search the registry index."""
registry = _get_registry()
if not _ensure_index_available(registry):
return 1
results = registry.search(args.query)
if args.output_json:
_emit_json(results)
return 0
if not results:
print(f"No servers matching '{args.query}'.")
return 0
print(f"Found {len(results)} server(s) matching '{args.query}':\n")
_render_available_table(results)
return 0
def cmd_mcp_health(args) -> int:
"""Check server health."""
registry = _get_registry()
try:
results = registry.health_check(name=args.name)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
# Single server returns a flat dict, all-servers returns name->dict
if args.name:
results = {args.name: results}
if args.output_json:
_emit_json(results)
return 0
for name, result in results.items():
status = result.get("status", "unknown")
tools = result.get("tools", 0)
error = result.get("error")
sym = {"healthy": "", "unhealthy": ""}.get(status, "")
print(f" {sym} {name}: {status}", end="")
if status == "healthy" and tools:
print(f" ({tools} tools)")
elif error:
print(f"\n Error: {error}")
else:
print()
return 0
def cmd_mcp_update(args) -> int:
"""Update a single server, or refresh the index and update all registry servers."""
registry = _get_registry()
if args.name:
return _cmd_mcp_update_server(args.name, registry)
# Step 1: refresh the registry index
try:
count = registry.update_index()
except Exception as exc:
print(
f"Error: failed to update registry index: {exc}\n"
f"Check your network connection and try again.",
file=sys.stderr,
)
return 1
print(f"✓ Registry index updated ({count} servers available)")
# Step 2: update all installed registry servers (skip local/pinned)
installed = registry.list_installed()
registry_servers = [
s for s in installed if s.get("source") == "registry" and not s.get("pinned")
]
if not registry_servers:
return 0
print(f"\nUpdating {len(registry_servers)} installed server(s)...")
errors = 0
for server in registry_servers:
name = server["name"]
rc = _cmd_mcp_update_server(name, registry)
if rc != 0:
errors += 1
return 1 if errors else 0
def _cmd_mcp_update_server(name: str, registry=None) -> int:
"""Bridge: reinstall a server from the latest index.
This is a temporary bridge until #6355 adds proper version diffing,
tool-signature change detection, and --dry-run support.
"""
if registry is None:
registry = _get_registry()
server = registry.get_server(name)
if server is None:
print(
f"Error: server '{name}' is not installed.\n"
f"Run 'hive mcp install {name}' to install it.",
file=sys.stderr,
)
return 1
if server.get("source") != "registry":
print(
f"Error: '{name}' is a local server and cannot be updated from the registry.\n"
f"Use 'hive mcp remove {name}' and 'hive mcp add' to re-register it.",
file=sys.stderr,
)
return 1
if server.get("pinned"):
print(
f"Error: '{name}' is pinned to v{server.get('manifest_version', '?')}.\n"
f"To update a pinned server, remove and reinstall:\n"
f" hive mcp remove {name} && hive mcp install {name}",
file=sys.stderr,
)
return 1
# Refresh index, then reinstall
if not _ensure_index_available(registry):
return 1
old_version = server.get("manifest_version", "unknown")
transport = server.get("transport")
overrides = server.get("overrides", {})
was_enabled = server.get("enabled", True)
# Save the full entry before removing so we can restore on failure
saved_entry = dict(server)
saved_entry.pop("name", None)
try:
registry.remove(name)
entry = registry.install(name, transport=transport)
except ValueError as exc:
# Restore the original entry so update doesn't become an uninstall
data = registry._read_installed()
data["servers"][name] = saved_entry
registry._write_installed(data)
print(
f"Error: {exc}\nServer '{name}' has been restored to its previous state.",
file=sys.stderr,
)
return 1
new_version = entry.get("manifest_version", "unknown")
# Restore prior state from the previous installation
for key, value in overrides.get("env", {}).items():
registry.set_override(name, key, value, override_type="env")
for key, value in overrides.get("headers", {}).items():
registry.set_override(name, key, value, override_type="headers")
if not was_enabled:
registry.disable(name)
if old_version == new_version:
print(f"{name} is already at v{new_version}")
else:
print(f"✓ Updated {name}: v{old_version} → v{new_version}")
return 0
@@ -0,0 +1,252 @@
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
def resolve_registry_servers(
*,
include: list[str] | None = None,
tags: list[str] | None = None,
exclude: list[str] | None = None,
profile: str | None = None,
max_tools: int | None = None,
versions: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
"""
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
This function is written to be mock-friendly during early development:
- If the real `MCPRegistry` core module is present, delegate to it.
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
and then to the repo fixture index.
"""
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
# signature to match the PRD and future MCPRegistry interfaces.
_ = max_tools
try:
from framework.runner.mcp_registry import MCPRegistry # type: ignore
registry = MCPRegistry()
resolved = registry.resolve_for_agent(
include=include or [],
tags=tags or [],
exclude=exclude or [],
profile=profile,
max_tools=max_tools,
versions=versions or {},
)
# Future-proof: normalize both dicts and typed objects to dicts.
return [_normalize_server_config(x) for x in resolved]
except ImportError:
# Expected while #6349/#6574 is not merged locally.
pass
except Exception as e:
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
return _resolve_from_local_index(
include=include,
tags=tags,
exclude=exclude,
profile=profile,
versions=versions or {},
)
def _resolve_from_local_index(
*,
include: list[str] | None,
tags: list[str] | None,
exclude: list[str] | None,
profile: str | None,
versions: dict[str, str],
) -> list[dict[str, Any]]:
index = _load_index_json()
servers = _coerce_index_servers(index)
servers_by_name: dict[str, dict[str, Any]] = {
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
}
include_list = include or []
tags_list = tags or []
exclude_set = set(exclude or [])
def _profiles_of(entry: dict[str, Any]) -> set[str]:
if isinstance(entry.get("profiles"), list):
return set(entry["profiles"])
hive = entry.get("hive")
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
return set(hive["profiles"])
return set()
def _tags_of(entry: dict[str, Any]) -> set[str]:
if isinstance(entry.get("tags"), list):
return set(entry["tags"])
return set()
def _entry_version(entry: dict[str, Any]) -> str | None:
# Prefer flat `version`, but support a few common shapes.
v = entry.get("version")
if isinstance(v, str):
return v
v2 = entry.get("manifest_version")
if isinstance(v2, str):
return v2
hive = entry.get("manifest")
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
return hive["version"]
return None
def _version_allows(server_name: str) -> bool:
if server_name not in versions:
return True
pinned = versions[server_name]
entry = servers_by_name.get(server_name)
if not entry:
return False
return _entry_version(entry) == pinned
resolved_names: list[str] = []
resolved_set: set[str] = set()
# 1) Include-order first
for name in include_list:
if name in exclude_set:
continue
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
resolved_names.append(name)
resolved_set.add(name)
# 2) Then tag/profile matches, alphabetical
profile_candidates = set()
if profile:
for name, entry in servers_by_name.items():
if name in exclude_set or not _version_allows(name):
continue
if profile in _profiles_of(entry):
profile_candidates.add(name)
tag_candidates = set()
if tags_list:
tags_set = set(tags_list)
for name, entry in servers_by_name.items():
if name in exclude_set or not _version_allows(name):
continue
if _tags_of(entry).intersection(tags_set):
tag_candidates.add(name)
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
resolved_names.extend(tag_profile_names)
# Missing requested servers should warn (FR-54).
for name in include_list:
if name in exclude_set:
continue
if name not in resolved_set:
if name not in servers_by_name:
logger.warning(
"Server '%s' requested by mcp_registry.json but not found in index. "
"Run: hive mcp install %s",
name,
name,
)
elif name in versions:
logger.warning(
"Server '%s' was requested but pinned version '%s' was not found in index. "
"Run: hive mcp update %s or change the pin in mcp_registry.json",
name,
versions[name],
name,
)
else:
logger.warning(
"Server '%s' requested by mcp_registry.json was not selected. "
"Check selection filters/exclude lists.",
name,
)
resolved_configs: list[dict[str, Any]] = []
repo_root = Path(__file__).resolve().parents[3]
for name in resolved_names:
entry = servers_by_name.get(name)
if not entry:
continue
config = entry.get("mcp_config")
if not isinstance(config, dict):
# Best-effort: allow a direct MCP config shape at top-level.
config = {
k: v
for k, v in entry.items()
if k
in {
"name",
"transport",
"command",
"args",
"env",
"cwd",
"url",
"headers",
"description",
}
}
mcp_config = dict(config)
mcp_config["name"] = name
if mcp_config.get("transport") == "stdio":
_absolutize_stdio_config_in_place(repo_root, mcp_config)
resolved_configs.append(mcp_config)
return resolved_configs
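The ordering rule implemented above (include order first, then tag/profile matches alphabetically, minus excludes and duplicates) can be reduced to a minimal sketch; the server names are made up for illustration:

```python
def resolution_order(include: list, tag_matches: list, exclude: list) -> list:
    """Sketch of the selection order: include order wins, then sorted tag/profile hits."""
    excluded = set(exclude)
    resolved, seen = [], set()
    for name in include:
        if name not in excluded and name not in seen:
            resolved.append(name)
            seen.add(name)
    # Tag/profile candidates come after, alphabetically, deduplicated.
    resolved.extend(sorted(set(tag_matches) - excluded - seen))
    return resolved

order = resolution_order(["zeta", "alpha"], ["beta", "alpha", "gamma"], ["gamma"])
assert order == ["zeta", "alpha", "beta"]
```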
def _load_index_json() -> Any:
if _CACHE_INDEX_PATH.exists():
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
if _FIXTURE_INDEX_PATH.exists():
logger.info("Using local fixture index because registry cache is missing")
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
logger.warning("No local MCP registry index found (cache and fixture missing)")
return {"servers": []}
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
if isinstance(index, list):
return [x for x in index if isinstance(x, dict)]
if isinstance(index, dict):
servers = index.get("servers", [])
if isinstance(servers, list):
return [x for x in servers if isinstance(x, dict)]
return []
def _normalize_server_config(raw: Any) -> dict[str, Any]:
if isinstance(raw, dict):
return dict(raw)
# Future-proof object-to-dict normalization.
for attr in ("to_dict", "model_dump"):
maybe = getattr(raw, attr, None)
if callable(maybe):
return dict(maybe())
return dict(getattr(raw, "__dict__", {}))
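The duck-typed normalization above accepts dicts today and typed objects later. A sketch with a hypothetical typed entry (`ManifestEntry` is invented for illustration):

```python
class ManifestEntry:
    """Stand-in for a future typed server config object."""
    def __init__(self, name: str, transport: str):
        self.name = name
        self.transport = transport

    def to_dict(self) -> dict:
        return {"name": self.name, "transport": self.transport}

def normalize(raw) -> dict:
    # Mirror of _normalize_server_config: dicts pass through; objects are
    # converted via to_dict()/model_dump(), else their __dict__.
    if isinstance(raw, dict):
        return dict(raw)
    for attr in ("to_dict", "model_dump"):
        maybe = getattr(raw, attr, None)
        if callable(maybe):
            return dict(maybe())
    return dict(getattr(raw, "__dict__", {}))

assert normalize({"name": "fs"}) == {"name": "fs"}
assert normalize(ManifestEntry("fs", "stdio")) == {"name": "fs", "transport": "stdio"}
```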
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
cwd = config.get("cwd")
if isinstance(cwd, str) and not Path(cwd).is_absolute():
config["cwd"] = str((repo_root / cwd).resolve())
# We intentionally do not absolutize `args` here.
# For stdio servers, arguments may include the script name relative to
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
# stdio resolution logic handles script path checks and platform quirks.
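The cwd-only absolutization described in the comment above can be demonstrated with a standalone mirror of the helper (returning the dict here only for ease of asserting; the real function mutates in place):

```python
from pathlib import Path

def absolutize_cwd(repo_root: Path, config: dict) -> dict:
    # Mirror of _absolutize_stdio_config_in_place: only a relative `cwd`
    # is rewritten; `args` are deliberately left alone.
    cwd = config.get("cwd")
    if isinstance(cwd, str) and not Path(cwd).is_absolute():
        config["cwd"] = str((repo_root / cwd).resolve())
    return config

cfg = absolutize_cwd(Path("/repo"), {"cwd": "tools", "args": ["coder_tools_server.py"]})
assert cfg["cwd"] == "/repo/tools"
assert cfg["args"] == ["coder_tools_server.py"]  # untouched
```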
+13 -2
@@ -1429,12 +1429,18 @@ class AgentRunner:
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
"""Load and register MCP servers selected via ``mcp_registry.json``."""
registry_json = agent_path / "mcp_registry.json"
if registry_json.is_file():
self._tool_registry.set_mcp_registry_agent_path(agent_path)
else:
self._tool_registry.set_mcp_registry_agent_path(None)
from framework.runner.mcp_registry import MCPRegistry
try:
registry = MCPRegistry()
registry.initialize()
server_configs = registry.load_agent_selection(agent_path)
server_configs, selection_max_tools = registry.load_agent_selection(agent_path)
except Exception as exc:
logger.warning(
"Failed to load MCP registry servers for '%s': %s",
@@ -1446,7 +1452,12 @@ class AgentRunner:
if not server_configs:
return
results = self._tool_registry.load_registry_servers(server_configs)
results = self._tool_registry.load_registry_servers(
server_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
loaded = [result for result in results if result["status"] == "loaded"]
skipped = [result for result in results if result["status"] != "loaded"]
+128 -6
@@ -66,6 +66,8 @@ class ToolRegistry:
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
# Agent dir for re-loading registry MCP after credential resync.
self._mcp_registry_agent_path: Path | None = None
def register(
self,
@@ -490,7 +492,13 @@ class ToolRegistry:
self._resolve_mcp_server_config(server_config, base_dir)
for server_config in server_list
]
self.load_registry_servers(resolved_server_list, log_summary=False)
# Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
self.load_registry_servers(
resolved_server_list,
log_summary=False,
preserve_existing_tools=True,
log_collisions=False,
)
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
self._mcp_cred_snapshot = self._snapshot_credentials()
@@ -499,6 +507,10 @@ class ToolRegistry:
def _register_mcp_server_with_retry(
self,
server_config: dict[str, Any],
*,
preserve_existing_tools: bool = True,
tool_cap: int | None = None,
log_collisions: bool = False,
) -> tuple[bool, int, str | None]:
"""Register a single MCP server with one retry for transient failures."""
name = server_config.get("name", "unknown")
@@ -506,7 +518,12 @@ class ToolRegistry:
for attempt in range(2):
try:
count = self.register_mcp_server(server_config)
count = self.register_mcp_server(
server_config,
preserve_existing_tools=preserve_existing_tools,
tool_cap=tool_cap,
log_collisions=log_collisions,
)
if count > 0:
return True, count, None
last_error = "registered 0 tools"
@@ -532,13 +549,38 @@ class ToolRegistry:
server_list: list[dict[str, Any]],
*,
log_summary: bool = True,
preserve_existing_tools: bool = True,
max_tools: int | None = None,
log_collisions: bool = False,
) -> list[dict[str, Any]]:
"""Register resolved registry-selected MCP servers with retry and status tracking."""
"""Register MCP servers from a resolved config list (registry and/or static).
``preserve_existing_tools`` enforces first-wins tool names (FR-100): later
servers skip names already taken, including tools from ``mcp_servers.json``
or ``tools.py`` when those were loaded first.
``max_tools`` caps how many *new* tool names are registered across this batch
(collisions do not consume the cap). When ``log_collisions`` is True, skipped
duplicate names emit a warning (FR-101).
"""
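The first-wins and cap semantics described in the docstring above can be sketched independently of the registry; the server/tool names are illustrative:

```python
def register_batch(servers, existing, max_tools=None) -> int:
    """Sketch: first-wins registration with a batch-wide cap on *new* tools."""
    registered = set(existing)  # names loaded earlier (tools.py, mcp_servers.json) win
    added = 0
    for _server, tool_names in servers:
        for name in tool_names:
            if max_tools is not None and added >= max_tools:
                return added  # cap reached; remaining servers are skipped
            if name in registered:
                continue  # collision: does not consume the cap
            registered.add(name)
            added += 1
    return added

# "search" collides with an existing tool; the cap of 2 stops after two new names.
count = register_batch(
    [("srv-a", ["search", "fetch"]), ("srv-b", ["scrape", "render"])],
    existing={"search"},
    max_tools=2,
)
assert count == 2
```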
results: list[dict[str, Any]] = []
tools_added_batch = 0
for server_config in server_list:
remaining: int | None = None
if max_tools is not None:
remaining = max_tools - tools_added_batch
if remaining <= 0:
break
name = server_config.get("name", "unknown")
success, tools_loaded, error = self._register_mcp_server_with_retry(server_config)
success, tools_loaded, error = self._register_mcp_server_with_retry(
server_config,
preserve_existing_tools=preserve_existing_tools,
tool_cap=remaining,
log_collisions=log_collisions,
)
tools_added_batch += tools_loaded
result = {
"server": name,
"status": "loaded" if success else "skipped",
@@ -565,6 +607,10 @@ class ToolRegistry:
self,
server_config: dict[str, Any],
use_connection_manager: bool = True,
*,
preserve_existing_tools: bool = True,
tool_cap: int | None = None,
log_collisions: bool = False,
) -> int:
"""
Register an MCP server and discover its tools.
@@ -581,6 +627,9 @@ class ToolRegistry:
- headers: HTTP headers (for http)
- description: Server description (optional)
use_connection_manager: When True, reuse a shared client keyed by server name
preserve_existing_tools: If True, do not replace tools already in the registry.
tool_cap: Max tools to newly register from this server (None = unlimited).
log_collisions: If True, log when this server skips a tool name already taken.
Returns:
Number of tools registered from this server
@@ -623,6 +672,23 @@ class ToolRegistry:
self._mcp_server_tools[server_name] = set()
count = 0
for mcp_tool in client.list_tools():
if tool_cap is not None and count >= tool_cap:
break
if preserve_existing_tools and mcp_tool.name in self._tools:
if log_collisions:
origin_server = (
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
)
logger.warning(
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
mcp_tool.name,
server_name,
origin_server,
)
# Skip registration; do not update MCP tool bookkeeping for this server.
continue
# Convert MCP tool to framework Tool (strips context params from LLM schema)
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
@@ -688,11 +754,27 @@ class ToolRegistry:
self._mcp_server_tools[server_name].add(mcp_tool.name)
count += 1
logger.info(f"Registered {count} tools from MCP server '{config.name}'")
logger.info(
"MCP Registry Load",
extra={
"server": config.name,
"status": "success",
"tools_loaded": count,
"skipped_reason": None,
},
)
return count
except Exception as e:
logger.error(f"Failed to register MCP server: {e}")
logger.error(
"MCP Registry Load",
extra={
"server": server_config.get("name", "unknown"),
"status": "failed",
"tools_loaded": 0,
"skipped_reason": str(e),
},
)
if "Connection closed" in str(e) and os.name == "nt":
logger.debug(
"On Windows, check that the MCP subprocess starts (e.g. uv in PATH, "
@@ -700,6 +782,12 @@ class ToolRegistry:
)
return 0
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
for server_name, tool_names in self._mcp_server_tools.items():
if tool_name in tool_names:
return server_name
return None
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
"""
Convert an MCP tool to a framework Tool.
@@ -787,6 +875,37 @@ class ToolRegistry:
# MCP credential resync
# ------------------------------------------------------------------
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
"""Remember agent dir so registry MCP servers reload after credential resync."""
self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path)
def reload_registry_mcp_servers_after_resync(self) -> None:
"""Re-run ``mcp_registry.json`` resolution and register servers (post-resync)."""
if self._mcp_registry_agent_path is None:
return
from framework.runner.mcp_registry import MCPRegistry
try:
reg = MCPRegistry()
reg.initialize()
configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
except Exception as exc:
logger.warning(
"Failed to reload MCP registry servers after resync for '%s': %s",
self._mcp_registry_agent_path.name,
exc,
)
return
if not configs:
return
self.load_registry_servers(
configs,
log_summary=True,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
def _snapshot_credentials(self) -> set[str]:
"""Return the set of credential filenames currently on disk."""
try:
@@ -832,9 +951,12 @@ class ToolRegistry:
for name in self._mcp_tool_names:
self._tools.pop(name, None)
self._mcp_tool_names.clear()
self._mcp_server_tools.clear()
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
self.load_mcp_config(self._mcp_config_path)
if self._mcp_registry_agent_path is not None:
self.reload_registry_mcp_servers_after_resync()
logger.info("MCP server resync complete")
return True
+4
@@ -200,6 +200,8 @@ class AgentRuntime:
self._skills_manager.load()
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
self.context_warn_ratio: float | None = self._skills_manager.context_warn_ratio
self.batch_init_nudge: str | None = self._skills_manager.batch_init_nudge
# Primary graph identity
self._graph_id: str = graph_id or "primary"
@@ -348,6 +350,8 @@ class AgentRuntime:
skills_catalog_prompt=self.skills_catalog_prompt,
protocols_prompt=self.protocols_prompt,
skill_dirs=self.skill_dirs,
context_warn_ratio=self.context_warn_ratio,
batch_init_nudge=self.batch_init_nudge,
)
await stream.start()
self._streams[ep_id] = stream
+4 -4
@@ -16,7 +16,7 @@ from typing import Any
from framework.observability import set_trace_context
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
from framework.schemas.run import Run, RunStatus
from framework.storage.backend import FileStorage
from framework.storage.concurrent import ConcurrentStorage
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class Runtime:
logger.warning(f"Storage path does not exist, creating: {path}")
path.mkdir(parents=True, exist_ok=True)
self.storage = FileStorage(storage_path)
self.storage = ConcurrentStorage(storage_path)
self._current_run: Run | None = None
self._current_node: str = "unknown"
@@ -132,8 +132,8 @@ class Runtime:
self._current_run.output_data = output_data or {}
self._current_run.complete(status, narrative)
# Save to storage
self.storage.save_run(self._current_run)
# Save to storage (sync — Runtime methods are not async)
self.storage.save_run_sync(self._current_run)
self._current_run = None
def set_node(self, node_id: str) -> None:
+12 -1
@@ -189,6 +189,8 @@ class ExecutionStream:
skills_catalog_prompt: str = "",
protocols_prompt: str = "",
skill_dirs: list[str] | None = None,
context_warn_ratio: float | None = None,
batch_init_nudge: str | None = None,
):
"""
Initialize execution stream.
@@ -215,6 +217,8 @@ class ExecutionStream:
skills_catalog_prompt: Available skills catalog for system prompt
protocols_prompt: Default skill operational protocols for system prompt
skill_dirs: Skill base directories for Tier 3 resource access
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
"""
self.stream_id = stream_id
self.entry_spec = entry_spec
@@ -239,6 +243,8 @@ class ExecutionStream:
self._skills_catalog_prompt = skills_catalog_prompt
self._protocols_prompt = protocols_prompt
self._skill_dirs: list[str] = skill_dirs or []
self._context_warn_ratio: float | None = context_warn_ratio
self._batch_init_nudge: str | None = batch_init_nudge
_es_logger = logging.getLogger(__name__)
if protocols_prompt:
@@ -703,6 +709,8 @@ class ExecutionStream:
skills_catalog_prompt=self._skills_catalog_prompt,
protocols_prompt=self._protocols_prompt,
skill_dirs=self._skill_dirs,
context_warn_ratio=self._context_warn_ratio,
batch_init_nudge=self._batch_init_nudge,
)
# Track executor so inject_input() can reach EventLoopNode instances
self._active_executors[execution_id] = executor
@@ -961,7 +969,10 @@ class ExecutionStream:
return
import json as _json
session_dir = self._session_store.get_session_path(execution_id)
try:
session_dir = self._session_store.get_session_path(execution_id)
except ValueError:
return
runs_file = session_dir / "runs.jsonl"
now = datetime.now()
record = {
+12 -2
@@ -90,9 +90,16 @@ async def create_queen(
try:
registry = MCPRegistry()
registry.initialize()
registry_configs = registry.load_agent_selection(queen_pkg_dir)
if (queen_pkg_dir / "mcp_registry.json").is_file():
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
if registry_configs:
results = queen_registry.load_registry_servers(registry_configs)
results = queen_registry.load_registry_servers(
registry_configs,
preserve_existing_tools=True,
log_collisions=True,
max_tools=selection_max_tools,
)
logger.info("Queen: loaded MCP registry servers: %s", results)
except Exception:
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
@@ -232,6 +239,7 @@ async def create_queen(
)
# ---- Default skill protocols -------------------------------------
_queen_skill_dirs: list[str] = []
try:
from framework.skills.manager import SkillsManager, SkillsManagerConfig
@@ -242,6 +250,7 @@ async def create_queen(
_queen_skills_mgr.load()
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
_queen_skill_dirs = _queen_skills_mgr.allowlisted_dirs
except Exception:
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
@@ -306,6 +315,7 @@ async def create_queen(
dynamic_tools_provider=phase_state.get_current_tools,
dynamic_prompt_provider=phase_state.get_current_prompt,
iteration_metadata_provider=lambda: {"phase": phase_state.phase},
skill_dirs=_queen_skill_dirs,
)
session.queen_executor = executor
+18 -3
@@ -9,27 +9,42 @@ from framework.skills.catalog import SkillCatalog
from framework.skills.config import DefaultSkillConfig, SkillsConfig
from framework.skills.defaults import DefaultSkillManager
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
from framework.skills.installer import (
fork_skill,
install_from_git,
install_from_registry,
remove_skill,
)
from framework.skills.manager import SkillsManager, SkillsManagerConfig
from framework.skills.models import TrustStatus
from framework.skills.parser import ParsedSkill, parse_skill_md
from framework.skills.registry import RegistryClient
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
from framework.skills.trust import TrustedRepoStore, TrustGate
from framework.skills.validator import ValidationResult, validate_strict
__all__ = [
"DefaultSkillConfig",
"DefaultSkillManager",
"DiscoveryConfig",
"ParsedSkill",
"RegistryClient",
"SkillCatalog",
"SkillDiscovery",
"SkillError",
"SkillErrorCode",
"SkillsConfig",
"SkillsManager",
"SkillsManagerConfig",
"TrustGate",
"TrustedRepoStore",
"TrustStatus",
"ValidationResult",
"fork_skill",
"install_from_git",
"install_from_registry",
"log_skill_error",
"parse_skill_md",
"remove_skill",
"validate_strict",
]
@@ -20,3 +20,5 @@ What to extract: URLs and key snippets (not full pages), relevant API fields
Before transitioning to the next phase/node, write a handoff summary to
`_handoff_context` with everything the next phase needs to know.
You will receive an alert when context reaches {{warn_at_usage_ratio_pct}}% — preserve immediately.
@@ -14,5 +14,5 @@ When a tool call fails:
2. Decide — transient: retry once. Structural fixable: fix and retry.
Structural unfixable: record as failed, move to next item.
Blocking all progress: record escalation note.
3. Adapt — if same tool failed 3+ times, stop using it and find alternative.
3. Adapt — if same tool failed {{max_retries_per_tool}}+ times, stop using it and find alternative.
Update plan in notes. Never silently drop the failed item.
@@ -8,7 +8,7 @@ metadata:
## Operational Protocol: Quality Self-Assessment
Every 5 iterations, self-assess:
Every {{assessment_interval}} iterations, self-assess:
1. On-task? Still working toward the stated objective?
2. Thorough? Cutting corners compared to earlier?
File diff suppressed because it is too large
+80 -2
@@ -8,6 +8,7 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from framework.skills.config import SkillsConfig
from framework.skills.parser import ParsedSkill, parse_skill_md
@@ -18,6 +19,56 @@ logger = logging.getLogger(__name__)
# Default skills directory relative to this module
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
# Default config values per skill — used for {{placeholder}} substitution
_SKILL_DEFAULTS: dict[str, dict[str, Any]] = {
"hive.quality-monitor": {"assessment_interval": 5},
"hive.error-recovery": {"max_retries_per_tool": 3},
"hive.context-preservation": {"warn_at_usage_ratio_pct": 45},
"hive.batch-ledger": {"checkpoint_every_n": 5},
}
# Keywords that indicate a batch processing scenario (DS-12)
_BATCH_KEYWORDS: tuple[str, ...] = (
"list of",
"collection of",
"set of",
"batch of",
"each item",
"for each",
"process all",
"records",
"entries",
"rows",
"items",
)
_BATCH_INIT_NUDGE = (
"Note: your input appears to describe a batch operation. "
"Initialize `_batch_ledger` with the total item count before processing."
)
def is_batch_scenario(text: str) -> bool:
"""Return True if *text* contains batch-processing indicators (DS-12)."""
lower = text.lower()
return any(kw in lower for kw in _BATCH_KEYWORDS)
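The detection above is a plain case-insensitive substring scan. A sketch with a reduced keyword subset (the full tuple is in `_BATCH_KEYWORDS`):

```python
BATCH_KEYWORDS = ("list of", "for each", "process all", "records", "rows")

def is_batch(text: str) -> bool:
    # Case-insensitive substring scan, as in is_batch_scenario above.
    lower = text.lower()
    return any(kw in lower for kw in BATCH_KEYWORDS)

assert is_batch("Process all 300 Records in the export") is True
assert is_batch("Summarize this article") is False
```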
def _apply_overrides(skill_name: str, body: str, overrides: dict[str, Any]) -> str:
"""Substitute {{placeholder}} values in a skill body using overrides + defaults."""
defaults = _SKILL_DEFAULTS.get(skill_name, {})
# Convert float warn_at_usage_ratio → warn_at_usage_ratio_pct for the placeholder
if "warn_at_usage_ratio" in overrides:
overrides = dict(overrides)
overrides.setdefault(
"warn_at_usage_ratio_pct", int(float(overrides["warn_at_usage_ratio"]) * 100)
)
values = {**defaults, **overrides}
for key, val in values.items():
body = body.replace(f"{{{{{key}}}}}", str(val))
return body
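The placeholder substitution above boils down to a literal `{{key}}` replace. A standalone mirror (the protocol line comes from the quality-monitor skill shown elsewhere in this diff):

```python
def apply_placeholders(body: str, values: dict) -> str:
    # Same substitution as _apply_overrides: each {{key}} becomes str(value).
    # The f-string's quadruple braces render as literal double braces.
    for key, val in values.items():
        body = body.replace(f"{{{{{key}}}}}", str(val))
    return body

protocol = "Every {{assessment_interval}} iterations, self-assess:"
assert apply_placeholders(protocol, {"assessment_interval": 5}) == (
    "Every 5 iterations, self-assess:"
)
```

Overrides layered on top of `_SKILL_DEFAULTS` win because they are merged last into `values`.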
# Ordered list of default skills (name → directory)
SKILL_REGISTRY: dict[str, str] = {
"hive.note-taking": "note-taking",
@@ -123,8 +174,10 @@ class DefaultSkillManager:
skill = self._skills.get(skill_name)
if skill is None:
continue
# Use the full body — each SKILL.md contains exactly one protocol section
parts.append(skill.body)
# Apply config overrides to {{placeholder}} values before injection
overrides = self._config.get_default_overrides(skill_name)
body = _apply_overrides(skill_name, skill.body, overrides)
parts.append(body)
if len(parts) <= 1:
return ""
@@ -198,3 +251,28 @@ class DefaultSkillManager:
def active_skills(self) -> dict[str, ParsedSkill]:
"""All active default skills keyed by name."""
return dict(self._skills)
@property
def batch_init_nudge(self) -> str | None:
"""Nudge text to prepend to the system prompt when batch input is detected (DS-12).
Returns None if ``hive.batch-ledger`` is disabled or auto_detect_batch is False.
"""
if "hive.batch-ledger" not in self._skills:
return None
overrides = self._config.get_default_overrides("hive.batch-ledger")
if overrides.get("auto_detect_batch") is False:
return None
return _BATCH_INIT_NUDGE
@property
def context_warn_ratio(self) -> float | None:
"""Token usage ratio at which to inject a context preservation warning (DS-13).
Returns None if ``hive.context-preservation`` is disabled.
Defaults to 0.45 when the skill is active but no override is set.
"""
if "hive.context-preservation" not in self._skills:
return None
overrides = self._config.get_default_overrides("hive.context-preservation")
return float(overrides.get("warn_at_usage_ratio", 0.45))
+348
@@ -0,0 +1,348 @@
"""Skill install, remove, and fork operations.
Handles filesystem operations for the hive skill CLI:
- install_from_git: git clone --depth=1, then copy to the target directory
- install_from_registry: resolve the registry entry, then delegate to install_from_git
- remove_skill: delete a skill from ~/.hive/skills/
- fork_skill: copy a skill to a new location with a new name
- maybe_show_install_notice: one-time security notice on first install (NFR-5)
"""
from __future__ import annotations
import shutil
import subprocess
import tempfile
from pathlib import Path
from framework.skills.parser import ParsedSkill
from framework.skills.skill_errors import SkillError, SkillErrorCode
# Default install destination for user-scope skills
USER_SKILLS_DIR = Path.home() / ".hive" / "skills"
# Sentinel file for the one-time security notice on first install (NFR-5)
INSTALL_NOTICE_SENTINEL = Path.home() / ".hive" / ".install_notice_shown"
_INSTALL_NOTICE = """\
Security Notice: Installing Third-Party Skills
Skills are instructions executed by AI agents. A malicious
skill can manipulate agent behavior, exfiltrate data, or
cause unintended actions.
Only install skills from sources you trust. Review the
SKILL.md before running it in a production environment.
This notice is shown once. Use 'hive skill doctor' to audit
installed skills at any time.
"""
def maybe_show_install_notice() -> None:
"""Print a one-time security notice before the first skill install (NFR-5).
Touches a sentinel file in ~/.hive/ after showing the notice so it is
only displayed once across all future installs.
"""
if INSTALL_NOTICE_SENTINEL.exists():
return
print(_INSTALL_NOTICE, flush=True)
try:
INSTALL_NOTICE_SENTINEL.parent.mkdir(parents=True, exist_ok=True)
INSTALL_NOTICE_SENTINEL.touch()
except OSError:
pass # If we can't write the sentinel, just show the notice every time
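The one-time sentinel pattern above generalizes to any first-run notice. A minimal sketch using a temporary directory in place of `~/.hive/`:

```python
import tempfile
from pathlib import Path

def show_once(sentinel: Path, message: str) -> bool:
    """Print `message` only if `sentinel` does not exist yet, then create it."""
    if sentinel.exists():
        return False
    print(message, flush=True)
    try:
        sentinel.parent.mkdir(parents=True, exist_ok=True)
        sentinel.touch()
    except OSError:
        pass  # best effort: failing to persist just repeats the notice later
    return True

with tempfile.TemporaryDirectory() as tmp:
    sentinel = Path(tmp) / ".install_notice_shown"
    assert show_once(sentinel, "Security Notice") is True   # first install
    assert show_once(sentinel, "Security Notice") is False  # subsequent installs
```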
def install_from_git(
git_url: str,
skill_name: str,
subdirectory: str | None = None,
version: str | None = None,
target_dir: Path | None = None,
) -> Path:
"""Install a skill from a git repository.
Clones the repository with --depth=1 into a temporary directory, then
copies the skill subdirectory (or repo root) to the target location.
Args:
git_url: Git repository URL to clone.
skill_name: Name of the skill used as the install directory name.
subdirectory: Relative path within the repo to the skill directory.
If None, the repo root is treated as the skill directory.
version: Git ref to checkout (tag, branch, or commit). Defaults to
the remote's default branch.
target_dir: Where to install the skill. Defaults to
~/.hive/skills/<skill_name>/.
Returns:
Path to the installed skill directory (the parent of SKILL.md).
Raises:
SkillError: On any failure (git not found, clone failed, SKILL.md missing).
"""
if shutil.which("git") is None:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Cannot install '{skill_name}' from {git_url}",
why="git is not installed or not on PATH.",
fix="Install git (https://git-scm.com/) and retry.",
)
dest = (target_dir or USER_SKILLS_DIR) / skill_name
if dest.exists():
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Cannot install '{skill_name}'",
why=f"Directory already exists: {dest}",
fix=f"Run 'hive skill remove {skill_name}' first, or use a different --name.",
)
tmp_dir = tempfile.mkdtemp(prefix="hive-skill-install-")
try:
_git_clone_shallow(git_url, Path(tmp_dir), version=version)
# Locate the skill within the cloned repo
source_dir = Path(tmp_dir) / subdirectory if subdirectory else Path(tmp_dir)
skill_md = source_dir / "SKILL.md"
if not skill_md.exists():
raise SkillError(
code=SkillErrorCode.SKILL_NOT_FOUND,
what=f"No SKILL.md found in '{subdirectory or '/'}' of {git_url}",
why="The expected SKILL.md file is not present at the given path.",
fix=(
"Check the repository structure and use "
"'hive skill install --from <url>' with the correct subdirectory."
),
)
dest.parent.mkdir(parents=True, exist_ok=True)
_copy_skill_dir(source_dir, dest)
return dest
except SkillError:
raise
except Exception as exc:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Failed to install '{skill_name}' from {git_url}",
why=str(exc),
fix="Check the URL, your network connection, and git configuration.",
) from exc
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
def install_from_registry(
registry_entry: dict,
target_dir: Path | None = None,
version: str | None = None,
) -> Path:
"""Install a skill using a registry index entry.
Resolves the git_url and subdirectory from the registry entry and
delegates to install_from_git.
Args:
registry_entry: A skill entry dict from skill_index.json.
target_dir: Override install destination.
version: Override version (defaults to entry's 'version' field).
Returns:
Path to the installed skill directory.
Raises:
SkillError: If the registry entry is missing required fields or install fails.
"""
name = registry_entry.get("name")
git_url = registry_entry.get("git_url")
if not name or not git_url:
raise SkillError(
code=SkillErrorCode.SKILL_NOT_FOUND,
what="Incomplete registry entry — missing 'name' or 'git_url'.",
why="The registry index entry does not contain all required fields.",
fix="Report this issue to the registry maintainer.",
)
resolved_version = version or registry_entry.get("version")
subdirectory = registry_entry.get("subdirectory")
return install_from_git(
git_url=git_url,
skill_name=str(name),
subdirectory=subdirectory,
version=resolved_version,
target_dir=target_dir,
)
def remove_skill(name: str, skills_dir: Path | None = None) -> bool:
"""Remove an installed skill from the user skills directory.
Args:
name: Skill directory name to remove.
skills_dir: Override the search directory (default: ~/.hive/skills/).
Returns:
True if removed, False if not found.
Raises:
SkillError: If the directory exists but cannot be removed.
"""
target = (skills_dir or USER_SKILLS_DIR) / name
if not target.exists():
return False
try:
shutil.rmtree(target)
return True
except OSError as exc:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Failed to remove skill '{name}' at {target}",
why=str(exc),
fix="Check file permissions and try again.",
) from exc
def fork_skill(
source: ParsedSkill,
new_name: str,
target_dir: Path,
) -> Path:
"""Create a local editable copy of a skill with a new name.
Copies the skill's base directory to target_dir/new_name/ and rewrites
the 'name' field in the copied SKILL.md frontmatter.
Args:
source: The source skill to fork (from SkillDiscovery).
new_name: Name for the forked skill.
target_dir: Parent directory for the fork (e.g. ~/.hive/skills/).
Returns:
Path to the forked skill directory.
Raises:
SkillError: If the target already exists or the copy fails.
"""
dest = target_dir / new_name
if dest.exists():
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Cannot fork to '{dest}'",
why="Target directory already exists.",
fix=f"Choose a different --name or remove '{dest}' first.",
)
source_dir = Path(source.base_dir)
try:
dest.parent.mkdir(parents=True, exist_ok=True)
_copy_skill_dir(source_dir, dest)
except OSError as exc:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Failed to fork skill '{source.name}' to '{dest}'",
why=str(exc),
fix="Check file permissions and available disk space.",
) from exc
# Rewrite the name in the forked SKILL.md via YAML round-trip (safe)
forked_skill_md = dest / "SKILL.md"
if forked_skill_md.exists():
_rewrite_name_in_skill_md(forked_skill_md, new_name)
return dest
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _git_clone_shallow(git_url: str, target: Path, version: str | None = None) -> None:
"""Clone a git repo at --depth=1 into target directory.
Args:
git_url: Repository URL.
target: Destination directory (will be created by git).
version: Optional git ref (branch/tag) to clone.
Raises:
SkillError: If the clone fails.
"""
cmd = ["git", "clone", "--depth=1"]
if version:
cmd += ["--branch", version]
cmd += [git_url, str(target)]
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=60,
)
except subprocess.TimeoutExpired:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"git clone timed out for {git_url}",
why="The clone operation took longer than 60 seconds.",
fix="Check your network connection and retry.",
) from None
except (FileNotFoundError, OSError) as exc:
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"Cannot run git for {git_url}",
why=str(exc),
fix="Ensure git is installed and on PATH.",
) from exc
if result.returncode != 0:
stderr = result.stderr.strip()
raise SkillError(
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
what=f"git clone failed for {git_url}",
why=stderr or f"git exited with code {result.returncode}",
fix="Check the URL is correct and the repository is publicly accessible.",
)
def _copy_skill_dir(src: Path, dst: Path) -> None:
"""Copy a skill directory, ignoring VCS and cache artifacts."""
ignore = shutil.ignore_patterns(".git", "__pycache__", "*.pyc", ".venv", "venv", "node_modules")
shutil.copytree(src, dst, ignore=ignore)
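shutil.ignore_patterns applies at every directory level, so nested artifacts such as scripts/__pycache__/ are skipped too. A quick self-contained check in a temp directory:

```python
import shutil
import tempfile
from pathlib import Path

# Build a fake skill tree containing VCS and cache artifacts.
src = Path(tempfile.mkdtemp()) / "skill"
(src / ".git").mkdir(parents=True)
(src / "scripts" / "__pycache__").mkdir(parents=True)
(src / "SKILL.md").write_text("---\nname: demo\n---\nbody\n", encoding="utf-8")
(src / "scripts" / "run.py").write_text("print('hi')\n", encoding="utf-8")

dst = Path(tempfile.mkdtemp()) / "copy"
ignore = shutil.ignore_patterns(".git", "__pycache__", "*.pyc", ".venv", "venv", "node_modules")
shutil.copytree(src, dst, ignore=ignore)
```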
def _rewrite_name_in_skill_md(skill_md: Path, new_name: str) -> None:
"""Rewrite the 'name' field in a SKILL.md frontmatter via YAML round-trip.
Parses the frontmatter with yaml.safe_load, updates 'name', re-serializes
with yaml.dump, and reconstructs the file as:
---
<yaml>
---
<body>
Falls back to a no-op if the file can't be parsed (the copy is still usable).
"""
import yaml
try:
content = skill_md.read_text(encoding="utf-8")
parts = content.split("---", 2)
if len(parts) < 3:
return
frontmatter = yaml.safe_load(parts[1].strip())
if not isinstance(frontmatter, dict):
return
frontmatter["name"] = new_name
new_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True)
new_content = f"---\n{new_yaml}---\n{parts[2]}"
skill_md.write_text(new_content, encoding="utf-8")
except Exception:
pass # Degraded: forked copy works, name just isn't updated
+17
@@ -67,6 +67,7 @@ class SkillsManager:
self._catalog_prompt: str = ""
self._protocols_prompt: str = ""
self._allowlisted_dirs: list[str] = []
self._default_mgr: object = None # DefaultSkillManager, set after load()
# ------------------------------------------------------------------
# Factory for backwards-compat bridge
@@ -90,6 +91,7 @@ class SkillsManager:
mgr._catalog_prompt = skills_catalog_prompt
mgr._protocols_prompt = protocols_prompt
mgr._allowlisted_dirs = []
mgr._default_mgr = None
return mgr
# ------------------------------------------------------------------
@@ -146,6 +148,7 @@ class SkillsManager:
default_mgr.load()
default_mgr.log_active_skills()
protocols_prompt = default_mgr.build_protocols_prompt()
self._default_mgr = default_mgr
# DX-3: Community skill startup summary
if self._config.project_root is not None and not self._config.skip_community_discovery:
community_count = len(catalog._skills) if catalog_prompt else 0
@@ -189,6 +192,20 @@ class SkillsManager:
"""Skill base directories for Tier 3 resource access (AS-6)."""
return self._allowlisted_dirs
@property
def batch_init_nudge(self) -> str | None:
"""Batch init nudge text for DS-12 auto-detection, or None if disabled."""
if self._default_mgr is None:
return None
return self._default_mgr.batch_init_nudge # type: ignore[union-attr]
@property
def context_warn_ratio(self) -> float | None:
"""Token usage ratio for DS-13 context preservation warning, or None if disabled."""
if self._default_mgr is None:
return None
return self._default_mgr.context_warn_ratio # type: ignore[union-attr]
@property
def is_loaded(self) -> bool:
return self._loaded
+11 -2
@@ -211,6 +211,15 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
)
# Coerce compatibility / allowed-tools to list[str] — many SKILL.md files
# in the wild use a plain string instead of a YAML list.
raw_compat = frontmatter.get("compatibility")
if isinstance(raw_compat, str):
raw_compat = [raw_compat]
raw_tools = frontmatter.get("allowed-tools")
if isinstance(raw_tools, str):
raw_tools = [raw_tools]
return ParsedSkill(
name=name,
description=str(description).strip(),
@@ -219,7 +228,7 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
source_scope=source_scope,
body=body,
license=frontmatter.get("license"),
- compatibility=frontmatter.get("compatibility"),
+ compatibility=raw_compat,
metadata=frontmatter.get("metadata"),
- allowed_tools=frontmatter.get("allowed-tools"),
+ allowed_tools=raw_tools,
)
+206
@@ -0,0 +1,206 @@
"""Registry client for the Hive community skill registry.
Fetches the skill index from the hive-skill-registry GitHub repo, caches it
locally, and provides search and resolution utilities.
The registry repo (Phase 3) may not exist yet. All public methods degrade
gracefully, returning None or [] on any network or parse failure.
Configure a custom registry URL via the HIVE_REGISTRY_URL environment variable.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import UTC, datetime
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen
logger = logging.getLogger(__name__)
# Default registry index URL (Phase 3 repo, may not exist yet)
_DEFAULT_REGISTRY_URL = (
"https://raw.githubusercontent.com/hive-skill-registry/"
"hive-skill-registry/main/skill_index.json"
)
_CACHE_DIR = Path.home() / ".hive" / "registry_cache"
_CACHE_INDEX_PATH = _CACHE_DIR / "skill_index.json"
_CACHE_METADATA_PATH = _CACHE_DIR / "metadata.json"
_CACHE_TTL_SECONDS = 3600 # 1 hour
class RegistryClient:
"""Client for the Hive community skill registry.
All public methods return None / [] on any failure and never raise.
Network errors, parse failures, and missing registries are all
treated as graceful degradation.
"""
def __init__(
self,
registry_url: str | None = None,
cache_dir: Path | None = None,
) -> None:
self._url = registry_url or os.environ.get("HIVE_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
cache_root = cache_dir or _CACHE_DIR
self._index_path = cache_root / "skill_index.json"
self._metadata_path = cache_root / "metadata.json"
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def fetch_index(self, force_refresh: bool = False) -> dict | None:
"""Return the registry index dict.
Uses the local cache if it is fresh (within TTL) unless
force_refresh=True. Returns None on any failure.
"""
if not force_refresh and self._is_cache_fresh():
cached = self._load_cache()
if cached is not None:
return cached
raw = self._http_fetch(self._url)
if raw is None:
# Network unavailable — fall back to the stale cache (returns None if absent)
stale = self._load_cache()
if stale is not None:
logger.debug("registry: network unavailable, using stale cache")
return stale
try:
data = json.loads(raw.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
logger.warning("registry: failed to parse index JSON: %s", exc)
return self._load_cache()
if not isinstance(data, dict):
logger.warning("registry: index is not a JSON object")
return self._load_cache()
self._save_cache(data)
return data
def search(self, query: str) -> list[dict]:
"""Search registry skills by name, description, or tags.
Case-insensitive substring match. Returns [] if index unavailable.
"""
index = self.fetch_index()
if not index:
return []
skills = index.get("skills", [])
if not isinstance(skills, list):
return []
q = query.lower()
results = []
for entry in skills:
if not isinstance(entry, dict):
continue
name = str(entry.get("name", "")).lower()
description = str(entry.get("description", "")).lower()
tags = " ".join(str(t) for t in entry.get("tags", [])).lower()
if q in name or q in description or q in tags:
results.append(entry)
return results
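The matching rule is a case-insensitive substring test over name, description, and space-joined tags. It can be exercised standalone against a toy index (entries invented for illustration):

```python
def match(entry: dict, query: str) -> bool:
    # Same rule as RegistryClient.search: case-insensitive substring
    # over name, description, and space-joined tags.
    q = query.lower()
    name = str(entry.get("name", "")).lower()
    description = str(entry.get("description", "")).lower()
    tags = " ".join(str(t) for t in entry.get("tags", [])).lower()
    return q in name or q in description or q in tags

index = [  # toy entries, not the real registry
    {"name": "pdf-tools", "description": "Extract text from PDFs", "tags": ["documents"]},
    {"name": "web-research", "description": "Browse and summarize", "tags": ["search", "web"]},
]
hits = [e["name"] for e in index if match(e, "SEARCH")]  # matches via the "search" tag
```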
def get_skill_entry(self, name: str) -> dict | None:
"""Look up a single skill by exact name. Returns None if not found."""
index = self.fetch_index()
if not index:
return None
for entry in index.get("skills", []):
if isinstance(entry, dict) and entry.get("name") == name:
return entry
return None
def get_pack(self, pack_name: str) -> list[str] | None:
"""Return the list of skill names in a starter pack.
Returns None if the pack is not found or the index is unavailable.
"""
index = self.fetch_index()
if not index:
return None
for pack in index.get("packs", []):
if isinstance(pack, dict) and pack.get("name") == pack_name:
skills = pack.get("skills", [])
if isinstance(skills, list):
return [s for s in skills if isinstance(s, str)]
return None
def resolve_git_url(self, name: str) -> tuple[str, str | None] | None:
"""Return (git_url, subdirectory) for a skill name.
Returns None if the skill is not in the registry or the index
is unavailable.
"""
entry = self.get_skill_entry(name)
if not entry:
return None
git_url = entry.get("git_url")
if not git_url:
return None
subdirectory = entry.get("subdirectory") or None
return str(git_url), subdirectory
# ------------------------------------------------------------------
# Cache internals
# ------------------------------------------------------------------
def _load_cache(self) -> dict | None:
"""Read cached index from disk. Returns None if absent or unreadable."""
try:
data = json.loads(self._index_path.read_text(encoding="utf-8"))
return data if isinstance(data, dict) else None
except FileNotFoundError:
return None
except Exception as exc:
logger.debug("registry: could not read cache: %s", exc)
return None
def _save_cache(self, data: dict) -> None:
"""Write index to disk atomically (.tmp then rename)."""
try:
self._index_path.parent.mkdir(parents=True, exist_ok=True)
tmp = self._index_path.with_suffix(".tmp")
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
tmp.replace(self._index_path)
# Update metadata
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
meta_tmp = self._metadata_path.with_suffix(".tmp")
meta_tmp.write_text(json.dumps(meta, indent=2), encoding="utf-8")
meta_tmp.replace(self._metadata_path)
except Exception as exc:
logger.debug("registry: could not write cache: %s", exc)
def _is_cache_fresh(self) -> bool:
"""Return True if the cached index was fetched within the TTL."""
try:
meta = json.loads(self._metadata_path.read_text(encoding="utf-8"))
last_fetched = datetime.fromisoformat(meta["last_fetched"])
age = (datetime.now(tz=UTC) - last_fetched).total_seconds()
return age < _CACHE_TTL_SECONDS
except Exception:
return False
def _http_fetch(self, url: str, timeout: int = 10) -> bytes | None:
"""Fetch URL contents. Returns None on any network error — never raises."""
try:
with urlopen(url, timeout=timeout) as resp: # noqa: S310
return resp.read()
except URLError as exc:
logger.debug("registry: HTTP fetch failed for %s: %s", url, exc)
return None
except TimeoutError as exc:
logger.debug("registry: HTTP fetch timed out for %s: %s", url, exc)
return None
except Exception as exc:
logger.debug("registry: unexpected error fetching %s: %s", url, exc)
return None
+178
@@ -0,0 +1,178 @@
"""Strict SKILL.md validation for contributor tooling (hive skill validate).
Unlike the lenient parser used at runtime, this module applies hard-error rules
that match the Agent Skills specification exactly. Intended for contributor
tooling, CI gates, and hive skill doctor.
"""
from __future__ import annotations
import os
import stat
from dataclasses import dataclass, field
from pathlib import Path
from framework.skills.parser import _MAX_NAME_LENGTH
@dataclass
class ValidationResult:
"""Result of a strict SKILL.md validation run."""
passed: bool
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
def validate_strict(path: Path) -> ValidationResult:
"""Run all strict checks against a SKILL.md file.
Applies hard-error rules that go beyond the lenient runtime parser:
- name must be explicit (no directory-name fallback)
- YAML must parse without fixup
- name/directory mismatch is an error, not a warning
- empty body is an error
- scripts must be executable
Args:
path: Path to the SKILL.md file to validate.
Returns:
ValidationResult with passed=True if no errors, plus any warnings.
"""
errors: list[str] = []
warnings: list[str] = []
# 1. File exists and is readable
try:
content = path.read_text(encoding="utf-8")
except FileNotFoundError:
return ValidationResult(passed=False, errors=[f"File not found: {path}"])
except PermissionError:
return ValidationResult(passed=False, errors=[f"Permission denied reading: {path}"])
except OSError as exc:
return ValidationResult(passed=False, errors=[f"Cannot read file: {exc}"])
# 2. File not empty
if not content.strip():
return ValidationResult(passed=False, errors=["File is empty."])
# 3. YAML frontmatter present
parts = content.split("---", 2)
if len(parts) < 3:
return ValidationResult(
passed=False,
errors=["Missing YAML frontmatter — wrap frontmatter with --- delimiters."],
)
raw_yaml = parts[1].strip()
body = parts[2].strip()
if not raw_yaml:
return ValidationResult(
passed=False,
errors=["Frontmatter delimiters present but YAML block is empty."],
)
# 4. YAML parses WITHOUT fixup (strict: unquoted colons are an error)
import yaml
frontmatter: dict | None = None
try:
frontmatter = yaml.safe_load(raw_yaml)
except yaml.YAMLError as exc:
errors.append(
f"YAML parse error: {exc}. "
'Wrap values containing colons in quotes, e.g. description: "Use for: research".'
)
return ValidationResult(passed=False, errors=errors, warnings=warnings)
if not isinstance(frontmatter, dict):
return ValidationResult(
passed=False,
errors=["Frontmatter is not a YAML key-value mapping."],
)
# 5. description present and non-empty
description = frontmatter.get("description")
if not description or not str(description).strip():
errors.append("Missing required field: 'description' must be present and non-empty.")
# 6. name present and non-empty (no directory-name fallback in strict mode)
name = frontmatter.get("name")
if not name or not str(name).strip():
errors.append(
"Missing required field: 'name' must be present. "
"Add 'name: your-skill-name' to the frontmatter."
)
else:
name = str(name).strip()
parent_dir_name = path.parent.name
# 7. name length <= 64 chars
if len(name) > _MAX_NAME_LENGTH:
errors.append(
f"Skill name '{name}' is {len(name)} characters — "
f"maximum is {_MAX_NAME_LENGTH}. Shorten the name."
)
# 8. name matches parent directory (dot-namespace prefix allowed: hive.X with dir X)
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
errors.append(
f"Name '{name}' does not match directory '{parent_dir_name}'. "
f"Rename the directory to '{name}' or set name to '{parent_dir_name}'."
)
# 9. body non-empty
if not body:
errors.append(
"Skill body (instructions) is empty. "
"Add markdown instructions after the closing --- delimiter."
)
# 10. license present — warning only
if not frontmatter.get("license"):
warnings.append("No 'license' field — consider adding a license (e.g. MIT, Apache-2.0).")
# 11. Scripts in scripts/ exist and are executable (POSIX only —
# Windows does not use POSIX permission bits)
base_dir = path.parent
scripts_dir = base_dir / "scripts"
if scripts_dir.is_dir() and os.name != "nt":
for script_path in sorted(scripts_dir.iterdir()):
if script_path.is_file():
if not (script_path.stat().st_mode & (stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)):
errors.append(
f"Script not executable: {script_path.name}. Run: chmod +x {script_path}"
)
# 12. allowed-tools entries are non-empty strings — warning if malformed
allowed_tools = frontmatter.get("allowed-tools")
if allowed_tools is not None:
if not isinstance(allowed_tools, list):
warnings.append("'allowed-tools' should be a list of strings.")
else:
for tool in allowed_tools:
if not isinstance(tool, str) or not tool.strip():
warnings.append(f"'allowed-tools' entry {tool!r} is not a non-empty string.")
# 13. compatibility is a list of strings — error if malformed
compatibility = frontmatter.get("compatibility")
if compatibility is not None:
if not isinstance(compatibility, list):
errors.append("'compatibility' must be a list of strings.")
else:
for item in compatibility:
if not isinstance(item, str):
errors.append(f"'compatibility' entry {item!r} is not a string.")
# 14. metadata is a dict — error if malformed
metadata = frontmatter.get("metadata")
if metadata is not None and not isinstance(metadata, dict):
errors.append("'metadata' must be a YAML mapping (dict), not a scalar or list.")
return ValidationResult(
passed=len(errors) == 0,
errors=errors,
warnings=warnings,
)
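Both the lenient runtime parser and this validator rely on the same content.split("---", 2) delimiter convention. A minimal sketch of that split (YAML parsing of the middle part is omitted here; the sample file is invented):

```python
content = """---
name: demo-skill
description: "Use for: demonstrations"
---
Follow these steps when the skill activates.
"""

parts = content.split("---", 2)
# parts[0] is everything before the first delimiter (empty here),
# parts[1] is the raw YAML frontmatter, parts[2] is the markdown body.
has_frontmatter = len(parts) == 3
raw_yaml = parts[1].strip()
body = parts[2].strip()
```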
+2 -2
@@ -1,6 +1,6 @@
"""Storage backends for runtime data."""
- from framework.storage.backend import FileStorage
+ from framework.storage.concurrent import ConcurrentStorage
from framework.storage.conversation_store import FileConversationStore
- __all__ = ["FileStorage", "FileConversationStore"]
+ __all__ = ["ConcurrentStorage", "FileConversationStore"]
-266
@@ -1,266 +0,0 @@
"""
File-based storage backend for runtime data.
DEPRECATED: This storage backend is deprecated for new sessions.
New sessions use unified storage at sessions/{session_id}/state.json.
This module is kept for backward compatibility with old run data only.
Uses Pydantic's built-in serialization.
"""
import json
from pathlib import Path
from framework.schemas.run import Run, RunStatus, RunSummary
from framework.utils.io import atomic_write
class FileStorage:
"""
DEPRECATED: File-based storage for old runs only.
New sessions use unified storage at sessions/{session_id}/state.json.
This class is kept for backward compatibility with old run data.
Old directory structure (deprecated):
{base_path}/
runs/ # DEPRECATED - no longer written
{run_id}.json
summaries/ # DEPRECATED - no longer written
{run_id}.json
indexes/ # DEPRECATED - no longer written or read
by_goal/
{goal_id}.json
by_status/
{status}.json
by_node/
{node_id}.json
"""
def __init__(self, base_path: str | Path):
self.base_path = Path(base_path)
self._ensure_dirs()
def _ensure_dirs(self) -> None:
"""Create directory structure if it doesn't exist.
DEPRECATED: All directories (runs/, summaries/, indexes/) are deprecated.
New sessions use unified storage at sessions/{session_id}/state.json.
This method is now a no-op. Tests should not rely on this.
"""
# No-op: do not create deprecated directories
pass
def _validate_key(self, key: str) -> None:
"""
Validate key to prevent path traversal attacks.
Args:
key: The key to validate
Raises:
ValueError: If key contains path traversal or dangerous patterns
"""
if not key or key.strip() == "":
raise ValueError("Key cannot be empty")
# Block path separators
if "/" in key or "\\" in key:
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
# Block parent directory references
if ".." in key or key.startswith("."):
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
# Block absolute paths
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
# Block null bytes (Unix path injection)
if "\x00" in key:
raise ValueError("Invalid key format: null bytes not allowed")
# Block other dangerous special characters
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
if any(char in key for char in dangerous_chars):
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
# === RUN OPERATIONS ===
def save_run(self, run: Run) -> None:
"""Save a run to storage.
DEPRECATED: This method is now a no-op.
New sessions use unified storage at sessions/{session_id}/state.json.
Tests should not rely on FileStorage - use unified session storage instead.
"""
import warnings
warnings.warn(
"FileStorage.save_run() is deprecated. "
"New sessions use unified storage at sessions/{session_id}/state.json. "
"This write has been skipped.",
DeprecationWarning,
stacklevel=2,
)
# No-op: do not write to deprecated locations
def load_run(self, run_id: str) -> Run | None:
"""Load a run from storage."""
run_path = self.base_path / "runs" / f"{run_id}.json"
if not run_path.exists():
return None
with open(run_path, encoding="utf-8") as f:
return Run.model_validate_json(f.read())
def load_summary(self, run_id: str) -> RunSummary | None:
"""Load just the summary (faster than full run)."""
summary_path = self.base_path / "summaries" / f"{run_id}.json"
if not summary_path.exists():
# Fall back to computing from full run
run = self.load_run(run_id)
if run:
return RunSummary.from_run(run)
return None
with open(summary_path, encoding="utf-8") as f:
return RunSummary.model_validate_json(f.read())
def delete_run(self, run_id: str) -> bool:
"""Delete a run from storage."""
run_path = self.base_path / "runs" / f"{run_id}.json"
summary_path = self.base_path / "summaries" / f"{run_id}.json"
if not run_path.exists():
return False
# Load run to get index keys
run = self.load_run(run_id)
if run:
self._remove_from_index("by_goal", run.goal_id, run_id)
self._remove_from_index("by_status", run.status.value, run_id)
for node_id in run.metrics.nodes_executed:
self._remove_from_index("by_node", node_id, run_id)
run_path.unlink()
if summary_path.exists():
summary_path.unlink()
return True
# === QUERY OPERATIONS ===
def get_runs_by_goal(self, goal_id: str) -> list[str]:
"""Get all run IDs for a goal.
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
This method only returns old run IDs from deprecated indexes.
"""
import warnings
warnings.warn(
"FileStorage.get_runs_by_goal() is deprecated. "
"For new sessions, scan sessions/*/state.json instead.",
DeprecationWarning,
stacklevel=2,
)
return self._get_index("by_goal", goal_id)
def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
"""Get all run IDs with a status.
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
This method only returns old run IDs from deprecated indexes.
"""
import warnings
warnings.warn(
"FileStorage.get_runs_by_status() is deprecated. "
"For new sessions, scan sessions/*/state.json instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(status, RunStatus):
status = status.value
return self._get_index("by_status", status)
def get_runs_by_node(self, node_id: str) -> list[str]:
"""Get all run IDs that executed a node.
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
This method only returns old run IDs from deprecated indexes.
"""
import warnings
warnings.warn(
"FileStorage.get_runs_by_node() is deprecated. "
"For new sessions, scan sessions/*/state.json instead.",
DeprecationWarning,
stacklevel=2,
)
return self._get_index("by_node", node_id)
def list_all_runs(self) -> list[str]:
"""List all run IDs."""
runs_dir = self.base_path / "runs"
return [f.stem for f in runs_dir.glob("*.json")]
def list_all_goals(self) -> list[str]:
"""List all goal IDs that have runs.
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
This method only returns goals from old run IDs in deprecated indexes.
"""
import warnings
warnings.warn(
"FileStorage.list_all_goals() is deprecated. "
"For new sessions, scan sessions/*/state.json instead.",
DeprecationWarning,
stacklevel=2,
)
goals_dir = self.base_path / "indexes" / "by_goal"
if not goals_dir.exists():
return []
return [f.stem for f in goals_dir.glob("*.json")]
# === INDEX OPERATIONS ===
def _get_index(self, index_type: str, key: str) -> list[str]:
"""Get values from an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
if not index_path.exists():
return []
with open(index_path, encoding="utf-8") as f:
return json.load(f)
def _add_to_index(self, index_type: str, key: str, value: str) -> None:
"""Add a value to an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
values = self._get_index(index_type, key) # Already validated in _get_index
if value not in values:
values.append(value)
with atomic_write(index_path) as f:
json.dump(values, f, indent=2)
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
"""Remove a value from an index."""
self._validate_key(key) # Prevent path traversal
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
values = self._get_index(index_type, key) # Already validated in _get_index
if value in values:
values.remove(value)
with atomic_write(index_path) as f:
json.dump(values, f, indent=2)
# === UTILITY ===
def get_stats(self) -> dict:
"""Get storage statistics."""
return {
"total_runs": len(self.list_all_runs()),
"total_goals": len(self.list_all_goals()),
"storage_path": str(self.base_path),
}
+133 -72
@@ -1,7 +1,7 @@
"""
Concurrent Storage - Thread-safe storage backend with file locking.
Wraps FileStorage with:
Provides:
- Async file locking for atomic writes
- Write batching for performance
- Read caching for concurrent access
@@ -16,8 +16,8 @@ from pathlib import Path
from typing import Any
from weakref import WeakValueDictionary
- from framework.schemas.run import Run, RunStatus, RunSummary
- from framework.storage.backend import FileStorage
+ from framework.schemas.run import Run, RunSummary
+ from framework.utils.io import atomic_write
logger = logging.getLogger(__name__)
@@ -41,7 +41,6 @@ class ConcurrentStorage:
- Async file locking to prevent concurrent write corruption
- Write batching to reduce I/O overhead
- Read caching for frequently accessed data
- Compatible API with FileStorage
Example:
storage = ConcurrentStorage("/path/to/storage")
@@ -75,7 +74,6 @@ class ConcurrentStorage:
max_locks: Maximum number of active file locks to track strongly
"""
self.base_path = Path(base_path)
- self._base_storage = FileStorage(base_path)
# Caching
self._cache: dict[str, CacheEntry] = {}
@@ -157,6 +155,93 @@ class ConcurrentStorage:
return lock
# === KEY VALIDATION ===
@staticmethod
def _validate_key(key: str) -> None:
"""Validate key to prevent path traversal attacks.
Args:
key: The key to validate
Raises:
ValueError: If key contains path traversal or dangerous patterns
"""
if not key or key.strip() == "":
raise ValueError("Key cannot be empty")
if "/" in key or "\\" in key:
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
if ".." in key or key.startswith("."):
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
if "\x00" in key:
raise ValueError("Invalid key format: null bytes not allowed")
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
if any(char in key for char in dangerous_chars):
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
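The rejection rules above can be condensed into a predicate and spot-checked; this standalone copy mirrors the same checks but returns a bool instead of raising:

```python
def validate_key(key: str) -> bool:
    """Condensed standalone version of _validate_key: True if the key is safe."""
    if not key or not key.strip():
        return False
    if "/" in key or "\\" in key or "\x00" in key:
        return False  # path separators and null bytes
    if ".." in key or key.startswith("."):
        return False  # parent-directory references and hidden files
    if len(key) > 1 and key[1] == ":":
        return False  # Windows drive-letter absolute paths
    return not any(c in key for c in '<>|&$`\'"')  # shell metacharacters

safe = ["run-123", "abc_def", "session42"]
unsafe = ["../etc/passwd", "a/b", "C:evil", ".hidden", "x\x00y", "a|b", ""]
```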
# === FILE OPERATIONS (formerly in FileStorage) ===
def _save_run_sync(self, run: Run) -> None:
"""Persist a run to disk as ``runs/{run_id}.json``.
Uses an atomic write (temp-file + rename) so a mid-write crash
never leaves a partially written file on disk.
"""
self._validate_key(run.id)
runs_dir = self.base_path / "runs"
runs_dir.mkdir(parents=True, exist_ok=True)
run_path = runs_dir / f"{run.id}.json"
with atomic_write(run_path) as f:
f.write(run.model_dump_json(indent=2))
def _load_run_sync(self, run_id: str) -> Run | None:
"""Load a run from storage."""
run_path = self.base_path / "runs" / f"{run_id}.json"
if not run_path.exists():
return None
with open(run_path, encoding="utf-8") as f:
return Run.model_validate_json(f.read())
def _load_summary_sync(self, run_id: str) -> RunSummary | None:
"""Load just the summary (faster than full run)."""
self._validate_key(run_id)
summary_path = self.base_path / "summaries" / f"{run_id}.json"
if not summary_path.exists():
run = self._load_run_sync(run_id)
if run:
return RunSummary.from_run(run)
return None
with open(summary_path, encoding="utf-8") as f:
return RunSummary.model_validate_json(f.read())
def _delete_run_sync(self, run_id: str) -> bool:
"""Delete a run from storage."""
run_path = self.base_path / "runs" / f"{run_id}.json"
summary_path = self.base_path / "summaries" / f"{run_id}.json"
if not run_path.exists():
return False
run_path.unlink()
if summary_path.exists():
summary_path.unlink()
return True
def _list_all_runs_sync(self) -> list[str]:
"""List all run IDs."""
runs_dir = self.base_path / "runs"
if not runs_dir.exists():
return []
return [f.stem for f in runs_dir.glob("*.json")]
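`framework.utils.io.atomic_write` itself is not shown in this compare view; a minimal sketch of the temp-file + rename pattern the `_save_run_sync` docstring describes (the helper name and signature are assumed here, not taken from the real module):

```python
import os
import tempfile
from contextlib import contextmanager
from pathlib import Path

@contextmanager
def atomic_write(path: Path):
    # Write to a temp file in the target's own directory, then rename it
    # over the target. os.replace() is atomic on POSIX, so a crash
    # mid-write never leaves a partially written file at `path`.
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yield f
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)
        raise

target = Path(tempfile.mkdtemp()) / "run.json"
with atomic_write(target) as f:
    f.write('{"id": "run_1"}')
```

The temp file must live in the same directory as the target: `os.replace` is only atomic within a single filesystem.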
# === RUN OPERATIONS (Async, Thread-Safe) ===
async def save_run(self, run: Run, immediate: bool = False) -> None:
@@ -180,40 +265,17 @@ class ConcurrentStorage:
await self._write_queue.put(("run", run))
async def _save_run_locked(self, run: Run) -> None:
"""Save a run with file locking, including index locks."""
"""Save a run with file locking."""
lock_key = f"run:{run.id}"
# Helper to get lock
async def get_lock(k):
return await self._get_lock(k)
# Acquire main lock
run_lock = await get_lock(lock_key)
run_lock = await self._get_lock(lock_key)
async with run_lock:
# 2. Acquire index locks
index_lock_keys = [
f"index:by_goal:{run.goal_id}",
f"index:by_status:{run.status.value}",
]
for node_id in run.metrics.nodes_executed:
index_lock_keys.append(f"index:by_node:{node_id}")
# Collect index locks
index_locks = [await get_lock(k) for k in index_lock_keys]
# Recursive acquisition
async def with_locks(locks, callback):
if not locks:
return await callback()
async with locks[0]:
return await with_locks(locks[1:], callback)
async def perform_save():
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._base_storage.save_run, run)
await loop.run_in_executor(None, self._save_run_sync, run)
await with_locks(index_locks, perform_save)
await perform_save()
async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
"""
@@ -225,7 +287,11 @@ class ConcurrentStorage:
Returns:
Run object or None if not found
Raises:
ValueError: If run_id contains path traversal characters.
"""
self._validate_key(run_id)
if use_cache:
cache_key = f"run:{run_id}"
cached = self._cache.get(cache_key)
@@ -240,7 +306,7 @@ class ConcurrentStorage:
lock_key = f"run:{run_id}"
async with await self._get_lock(lock_key):
loop = asyncio.get_event_loop()
run = await loop.run_in_executor(None, self._base_storage.load_run, run_id)
run = await loop.run_in_executor(None, self._load_run_sync, run_id)
# Update cache
if run:
@@ -249,7 +315,12 @@ class ConcurrentStorage:
return run
async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary | None:
"""Load just the summary (faster than full run)."""
"""Load just the summary (faster than full run).
Raises:
ValueError: If run_id contains path traversal characters.
"""
self._validate_key(run_id)
cache_key = f"summary:{run_id}"
# Check cache
@@ -262,7 +333,7 @@ class ConcurrentStorage:
lock_key = f"summary:{run_id}"
async with await self._get_lock(lock_key):
loop = asyncio.get_event_loop()
summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
summary = await loop.run_in_executor(None, self._load_summary_sync, run_id)
# Update cache
if summary:
@@ -271,11 +342,16 @@ class ConcurrentStorage:
return summary
async def delete_run(self, run_id: str) -> bool:
"""Delete a run from storage."""
"""Delete a run from storage.
Raises:
ValueError: If run_id contains path traversal characters.
"""
self._validate_key(run_id)
lock_key = f"run:{run_id}"
async with await self._get_lock(lock_key):
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, self._base_storage.delete_run, run_id)
result = await loop.run_in_executor(None, self._delete_run_sync, run_id)
# Clear cache
self._cache.pop(f"run:{run_id}", None)
@@ -283,37 +359,10 @@ class ConcurrentStorage:
return result
# === QUERY OPERATIONS (Async, with Locking) ===
async def get_runs_by_goal(self, goal_id: str) -> list[str]:
"""Get all run IDs for a goal."""
async with await self._get_lock(f"index:by_goal:{goal_id}"):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._base_storage.get_runs_by_goal, goal_id)
async def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
"""Get all run IDs with a status."""
if isinstance(status, RunStatus):
status = status.value
async with await self._get_lock(f"index:by_status:{status}"):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._base_storage.get_runs_by_status, status)
async def get_runs_by_node(self, node_id: str) -> list[str]:
"""Get all run IDs that executed a node."""
async with await self._get_lock(f"index:by_node:{node_id}"):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._base_storage.get_runs_by_node, node_id)
async def list_all_runs(self) -> list[str]:
"""List all run IDs."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._base_storage.list_all_runs)
async def list_all_goals(self) -> list[str]:
"""List all goal IDs that have runs."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._base_storage.list_all_goals)
return await loop.run_in_executor(None, self._list_all_runs_sync)
# === BATCH OPERATIONS ===
@@ -411,10 +460,11 @@ class ConcurrentStorage:
async def get_stats(self) -> dict:
"""Get storage statistics."""
loop = asyncio.get_event_loop()
base_stats = await loop.run_in_executor(None, self._base_storage.get_stats)
all_runs = await loop.run_in_executor(None, self._list_all_runs_sync)
return {
**base_stats,
"total_runs": len(all_runs),
"storage_path": str(self.base_path),
"cache": self.get_cache_stats(),
"pending_writes": self._write_queue.qsize(),
"running": self._running,
@@ -423,10 +473,21 @@ class ConcurrentStorage:
# === SYNC API (for backward compatibility) ===
def save_run_sync(self, run: Run) -> None:
"""Synchronous save (uses base storage directly with lock)."""
# Use threading lock for sync operations
self._base_storage.save_run(run)
"""Synchronous save — persists a run to disk immediately."""
self._validate_key(run.id)
# Invalidate summary cache since the run data is changing
self._cache.pop(f"summary:{run.id}", None)
self._save_run_sync(run)
# Refresh run cache
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
def load_run_sync(self, run_id: str) -> Run | None:
"""Synchronous load (uses base storage directly)."""
return self._base_storage.load_run(run_id)
"""Synchronous load.
Raises:
ValueError: If run_id contains path traversal characters.
"""
self._validate_key(run_id)
return self._load_run_sync(run_id)
+7 -1
@@ -62,8 +62,14 @@ class SessionStore:
Returns:
Path to session directory
Raises:
ValueError: If session_id resolves outside the sessions directory
"""
return self.sessions_dir / session_id
resolved = (self.sessions_dir / session_id).resolve()
if not resolved.is_relative_to(self.sessions_dir.resolve()):
raise ValueError(f"Invalid session ID: {session_id}")
return resolved
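The resolve-then-contain pattern added here is worth noting: resolving first collapses `..` segments and symlinks, so the containment test cannot be bypassed the way naive string prefix checks can. A standalone sketch of the same check:

```python
from pathlib import Path

def safe_session_path(sessions_dir: Path, session_id: str) -> Path:
    # Resolve first, then verify containment: this rejects "../x",
    # absolute session IDs, and traversal via symlinked components,
    # all of which can slip past startswith()-style string checks.
    resolved = (sessions_dir / session_id).resolve()
    if not resolved.is_relative_to(sessions_dir.resolve()):  # Python 3.9+
        raise ValueError(f"Invalid session ID: {session_id}")
    return resolved
```

A quirk worth knowing: `Path("/a") / "/etc/passwd"` yields `/etc/passwd` (an absolute right operand replaces the left), which is exactly the case the containment check catches.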
def get_state_path(self, session_id: str) -> Path:
"""
+6 -2
@@ -73,7 +73,9 @@ class DebugTool:
Args:
test_storage: Storage for test and result data
runtime_storage: Optional FileStorage for Runtime data
runtime_storage: Optional storage backend for Runtime data.
Must expose a synchronous ``load_run_sync(run_id)`` method
(e.g. ``ConcurrentStorage``).
"""
self.test_storage = test_storage
self.runtime_storage = runtime_storage
@@ -233,7 +235,9 @@ class DebugTool:
return {}
try:
run = self.runtime_storage.load_run(run_id)
# Use the synchronous loader — _get_runtime_data is not async
# and ConcurrentStorage.load_run() is a coroutine.
run = self.runtime_storage.load_run_sync(run_id)
if not run:
return {"error": f"Run {run_id} not found"}
+1 -1
@@ -1,7 +1,7 @@
"""
File-based storage backend for test data.
Follows the same pattern as framework/storage/backend.py (FileStorage),
Follows the same pattern as framework/storage/concurrent.py (ConcurrentStorage),
storing tests as JSON files with indexes for efficient querying.
"""
+2 -2
@@ -49,7 +49,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
"""
from datetime import date, timedelta
from framework.agents.queen.queen_memory import read_episodic_memory
from framework.agents.queen.queen_memory import format_memory_date, read_episodic_memory
days_back = max(1, min(days_back, 30))
today = date.today()
@@ -70,7 +70,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
if not matched:
continue
content = "### ".join(matched)
label = d.strftime("%B %-d, %Y")
label = format_memory_date(d)
if d == today:
label = f"Today — {label}"
entry = f"## {label}\n\n{content}"
+2
@@ -2,6 +2,7 @@ import { Routes, Route } from "react-router-dom";
import Home from "./pages/home";
import MyAgents from "./pages/my-agents";
import Workspace from "./pages/workspace";
import NotFound from "./pages/not-found";
function App() {
return (
@@ -9,6 +10,7 @@ function App() {
<Route path="/" element={<Home />} />
<Route path="/my-agents" element={<MyAgents />} />
<Route path="/workspace" element={<Workspace />} />
<Route path="*" element={<NotFound />} />
</Routes>
);
}
+19
@@ -0,0 +1,19 @@
import { Link } from "react-router-dom";
export default function NotFound() {
return (
<div className="min-h-screen bg-background flex flex-col items-center justify-center px-6 text-center">
<h1 className="text-5xl font-semibold text-foreground">404</h1>
<p className="mt-3 text-sm text-muted-foreground">Page not found</p>
<p className="mt-1 text-sm text-muted-foreground/80">
The page you’re looking for doesn’t exist.
</p>
<Link
to="/"
className="mt-6 inline-flex items-center rounded-lg border border-border/40 px-4 py-2 text-sm font-medium text-foreground hover:bg-muted/40 transition-colors"
>
Back to Home
</Link>
</div>
);
}
+5 -5
@@ -32,7 +32,7 @@ class _FakeRegistry:
def load_agent_selection(self, agent_path: Path):
self.loaded_paths.append(agent_path)
return list(self._returned_configs)
return list(self._returned_configs), None
def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
@@ -61,7 +61,7 @@ def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
@@ -95,7 +95,7 @@ def test_agent_runner_skips_registry_when_no_servers_selected(tmp_path, monkeypa
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
@@ -135,7 +135,7 @@ def test_agent_runner_logs_actual_registry_load_results(tmp_path, monkeypatch):
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.load_registry_servers",
lambda self, server_configs: [
lambda self, server_configs, **kwargs: [
{"server": "jira", "status": "loaded", "tools_loaded": 2, "skipped_reason": None},
{
"server": "slack",
@@ -223,7 +223,7 @@ def test_integration_real_registry_to_agent_runner(tmp_path, monkeypatch):
registered: list[dict] = []
monkeypatch.setattr(
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
lambda self, server_config, use_connection_manager=True: (
lambda self, server_config, use_connection_manager=True, **kwargs: (
registered.append(server_config) or 1
),
)
+126
@@ -9,6 +9,7 @@ from framework.skills.defaults import (
SHARED_MEMORY_KEYS,
SKILL_REGISTRY,
DefaultSkillManager,
is_batch_scenario,
)
from framework.skills.parser import parse_skill_md
@@ -186,3 +187,128 @@ class TestSkillsConfig:
assert config.skills == []
assert config.default_skills == {}
assert config.all_defaults_disabled is False
class TestConfigOverrideSubstitution:
"""Config overrides replace {{placeholder}} values in injected protocol text."""
def test_quality_monitor_default_interval(self):
manager = DefaultSkillManager()
manager.load()
prompt = manager.build_protocols_prompt()
assert "Every 5 iterations" in prompt
def test_quality_monitor_override_interval(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.quality-monitor": {"assessment_interval": 10}}
)
manager = DefaultSkillManager(config)
manager.load()
prompt = manager.build_protocols_prompt()
assert "Every 10 iterations" in prompt
assert "Every 5 iterations" not in prompt
def test_error_recovery_default_retries(self):
manager = DefaultSkillManager()
manager.load()
prompt = manager.build_protocols_prompt()
assert "3+ times" in prompt
def test_error_recovery_override_retries(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.error-recovery": {"max_retries_per_tool": 5}}
)
manager = DefaultSkillManager(config)
manager.load()
prompt = manager.build_protocols_prompt()
assert "5+ times" in prompt
assert "3+ times" not in prompt
def test_context_preservation_default_threshold(self):
manager = DefaultSkillManager()
manager.load()
prompt = manager.build_protocols_prompt()
assert "45%" in prompt
def test_context_preservation_override_threshold(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.context-preservation": {"warn_at_usage_ratio": 0.4}}
)
manager = DefaultSkillManager(config)
manager.load()
prompt = manager.build_protocols_prompt()
assert "40%" in prompt
assert "45%" not in prompt
def test_no_unreplaced_placeholders_with_defaults(self):
"""All {{...}} placeholders should be replaced when using defaults."""
manager = DefaultSkillManager()
manager.load()
prompt = manager.build_protocols_prompt()
assert "{{" not in prompt
class TestBatchAutoDetection:
"""DS-12: is_batch_scenario() and batch_init_nudge property."""
def test_detects_list_of(self):
assert is_batch_scenario("process a list of 100 leads") is True
def test_detects_collection_of(self):
assert is_batch_scenario("a collection of invoices") is True
def test_detects_items(self):
assert is_batch_scenario("go through all items in the spreadsheet") is True
def test_detects_for_each(self):
assert is_batch_scenario("for each record, send an email") is True
def test_no_match_single_task(self):
assert is_batch_scenario("write a summary of the quarterly report") is False
def test_batch_nudge_active_by_default(self):
manager = DefaultSkillManager()
manager.load()
assert manager.batch_init_nudge is not None
assert "_batch_ledger" in manager.batch_init_nudge
def test_batch_nudge_none_when_skill_disabled(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.batch-ledger": {"enabled": False}}
)
manager = DefaultSkillManager(config)
manager.load()
assert manager.batch_init_nudge is None
def test_batch_nudge_none_when_auto_detect_disabled(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.batch-ledger": {"auto_detect_batch": False}}
)
manager = DefaultSkillManager(config)
manager.load()
assert manager.batch_init_nudge is None
class TestContextWarnRatio:
"""DS-13: context_warn_ratio property."""
def test_default_ratio(self):
manager = DefaultSkillManager()
manager.load()
assert manager.context_warn_ratio == pytest.approx(0.45)
def test_override_ratio(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.context-preservation": {"warn_at_usage_ratio": 0.3}}
)
manager = DefaultSkillManager(config)
manager.load()
assert manager.context_warn_ratio == pytest.approx(0.3)
def test_ratio_none_when_skill_disabled(self):
config = SkillsConfig.from_agent_vars(
default_skills={"hive.context-preservation": {"enabled": False}}
)
manager = DefaultSkillManager(config)
manager.load()
assert manager.context_warn_ratio is None
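The implementation of `is_batch_scenario` does not appear in this compare view; the DS-12 tests above pin its behavior, though. A plausible keyword-heuristic sketch (the marker list is inferred from the test cases and is an assumption, not the real `framework.skills.defaults` code):

```python
import re

# Phrase markers inferred from the DS-12 tests; the real implementation
# may use a different or longer list.
_BATCH_MARKERS = (r"\blist of\b", r"\bcollection of\b", r"\bitems\b", r"\bfor each\b")

def is_batch_scenario(goal: str) -> bool:
    """Heuristically detect goals that iterate over many records."""
    text = goal.lower()
    return any(re.search(marker, text) for marker in _BATCH_MARKERS)
```

Word boundaries matter: `\blist of\b` must not fire on "a summary of the quarterly report", which is exactly the negative case the tests assert.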
+105 -2
@@ -18,11 +18,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from framework.config import get_llm_extra_kwargs
from framework.llm.anthropic import AnthropicProvider
from framework.llm.litellm import (
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
LiteLLMProvider,
_compute_retry_delay,
_ensure_ollama_chat_prefix,
_is_ollama_model,
)
from framework.llm.provider import LLMProvider, LLMResponse, Tool
@@ -93,9 +96,9 @@ class TestLiteLLMProviderInit:
def test_init_ollama_no_key_needed(self):
"""Test that Ollama models don't require API key."""
with patch.dict(os.environ, {}, clear=True):
# Should not raise.
# Should not raise; ollama/ is normalised to ollama_chat/ for tool-call support.
provider = LiteLLMProvider(model="ollama/llama3")
assert provider.model == "ollama/llama3"
assert provider.model == "ollama_chat/llama3"
class TestLiteLLMProviderComplete:
@@ -1084,3 +1087,103 @@ class TestIsLocalModel:
from framework.runner.runner import AgentRunner
assert AgentRunner._is_local_model(model) is False
# ---------------------------------------------------------------------------
# Ollama helper functions
# ---------------------------------------------------------------------------
class TestIsOllamaModel:
"""Tests for _is_ollama_model()."""
@pytest.mark.parametrize(
"model",
[
"ollama/llama3",
"ollama/mistral:7b",
"ollama_chat/llama3",
"ollama_chat/qwen2.5:72b",
],
)
def test_ollama_models_return_true(self, model):
assert _is_ollama_model(model) is True
@pytest.mark.parametrize(
"model",
[
"gpt-4o-mini",
"anthropic/claude-3-haiku",
"openai/gpt-4o",
"gemini/gemini-1.5-flash",
"llama3",
"",
],
)
def test_non_ollama_models_return_false(self, model):
assert _is_ollama_model(model) is False
class TestEnsureOllamaChatPrefix:
"""Tests for _ensure_ollama_chat_prefix()."""
@pytest.mark.parametrize(
("input_model", "expected"),
[
("ollama/llama3", "ollama_chat/llama3"),
("ollama/mistral:7b", "ollama_chat/mistral:7b"),
("ollama/qwen2.5:72b-instruct", "ollama_chat/qwen2.5:72b-instruct"),
],
)
def test_rewrites_ollama_to_ollama_chat(self, input_model, expected):
assert _ensure_ollama_chat_prefix(input_model) == expected
@pytest.mark.parametrize(
"model",
[
"ollama_chat/llama3",
"gpt-4o-mini",
"anthropic/claude-3-haiku",
"gemini/gemini-1.5-flash",
"",
],
)
def test_leaves_non_ollama_prefix_unchanged(self, model):
assert _ensure_ollama_chat_prefix(model) == model
class TestGetLlmExtraKwargsOllama:
"""Tests for num_ctx injection via get_llm_extra_kwargs() for Ollama."""
def test_ollama_provider_returns_num_ctx(self):
"""Ollama config should inject num_ctx with default 16384."""
config = {
"llm": {"provider": "ollama", "model": "ollama/llama3"},
}
with patch("framework.config.get_hive_config", return_value=config):
result = get_llm_extra_kwargs()
assert result == {"num_ctx": 16384}
def test_ollama_provider_respects_custom_num_ctx(self):
"""User-specified num_ctx in config should take precedence."""
config = {
"llm": {"provider": "ollama", "model": "ollama/llama3", "num_ctx": 32768},
}
with patch("framework.config.get_hive_config", return_value=config):
result = get_llm_extra_kwargs()
assert result == {"num_ctx": 32768}
def test_non_ollama_provider_returns_empty(self):
"""Non-Ollama provider without subscriptions should return empty dict."""
config = {
"llm": {"provider": "anthropic", "model": "claude-3-haiku"},
}
with patch("framework.config.get_hive_config", return_value=config):
result = get_llm_extra_kwargs()
assert result == {}
def test_empty_config_returns_empty(self):
"""Missing config should return empty dict."""
with patch("framework.config.get_hive_config", return_value={}):
result = get_llm_extra_kwargs()
assert result == {}
+65
@@ -0,0 +1,65 @@
"""Tests for MCP structured error formatting."""
import pytest
from framework.runner.mcp_errors import (
MCPAuthError,
MCPError,
MCPErrorCode,
MCPToolNotFoundError,
)
def test_mcp_error_code_stored():
err = MCPError(
code=MCPErrorCode.MCP_AUTH_MISSING,
what="Could not connect to server 'jira'",
why="JIRA_API_TOKEN is not set",
fix="Run: hive mcp config jira --set JIRA_API_TOKEN=<token>",
)
assert err.code == MCPErrorCode.MCP_AUTH_MISSING
def test_mcp_error_message_format():
err = MCPError(
code=MCPErrorCode.MCP_AUTH_MISSING,
what="Could not connect to server 'jira'",
why="JIRA_API_TOKEN is not set",
fix="Run: hive mcp config jira --set JIRA_API_TOKEN=<token>",
)
expected = (
"[MCP_AUTH_MISSING]\n"
"What failed: Could not connect to server 'jira'\n"
"Why: JIRA_API_TOKEN is not set\n"
"Fix: Run: hive mcp config jira --set JIRA_API_TOKEN=<token>"
)
assert str(err) == expected
def test_mcp_tool_not_found_error():
err = MCPToolNotFoundError(server="github", tool_name="create_pr")
assert err.code == MCPErrorCode.MCP_TOOL_NOT_FOUND
assert "create_pr" in str(err)
assert "github" in str(err)
def test_mcp_auth_error():
err = MCPAuthError(server="jira", env_var="JIRA_API_TOKEN")
assert err.code == MCPErrorCode.MCP_AUTH_MISSING
assert "JIRA_API_TOKEN" in str(err)
def test_mcp_client_raises_structured_error_for_missing_tool():
from framework.runner.mcp_client import MCPClient, MCPServerConfig
config = MCPServerConfig(name="test-server", transport="stdio")
client = MCPClient(config)
client._connected = True
client._tools = {} # empty — no tools registered
with pytest.raises(MCPToolNotFoundError) as exc_info:
client.call_tool("nonexistent_tool", {})
assert exc_info.value.code == MCPErrorCode.MCP_TOOL_NOT_FOUND
assert "test-server" in str(exc_info.value)
assert "nonexistent_tool" in str(exc_info.value)
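The what/why/fix message format these tests lock down can be sketched as a small exception class (the real `framework.runner.mcp_errors` classes are not shown in this view, so the constructor shape here is an assumption; the `__str__` layout matches the expected string asserted above):

```python
from enum import Enum

class MCPErrorCode(str, Enum):
    # Subset of codes used in the tests; the real enum is larger.
    MCP_AUTH_MISSING = "MCP_AUTH_MISSING"
    MCP_TOOL_NOT_FOUND = "MCP_TOOL_NOT_FOUND"

class MCPError(Exception):
    """Structured MCP failure: machine-readable code plus what/why/fix."""

    def __init__(self, code: MCPErrorCode, what: str, why: str, fix: str):
        self.code = code
        self.what = what
        self.why = why
        self.fix = fix
        super().__init__(what)

    def __str__(self) -> str:
        # Four-line format asserted by test_mcp_error_message_format.
        return (
            f"[{self.code.value}]\n"
            f"What failed: {self.what}\n"
            f"Why: {self.why}\n"
            f"Fix: {self.fix}"
        )
```

Subclasses like `MCPAuthError` presumably just pre-fill these four fields from `server`/`env_var`, which is why the tests only check substrings.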
+6 -5
@@ -619,8 +619,9 @@ def test_load_agent_selection(tmp_path: Path):
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text(json.dumps({"include": ["jira", "slack"]}))
dicts = registry.load_agent_selection(agent_dir)
assert len(dicts) == 2 and all("transport" in d for d in dicts)
dicts, max_tools = registry.load_agent_selection(agent_dir)
assert len(dicts) == 2 and max_tools is None
assert all("transport" in d for d in dicts)
def test_load_agent_selection_no_file(tmp_path: Path):
@@ -628,7 +629,7 @@ def test_load_agent_selection_no_file(tmp_path: Path):
registry.initialize()
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
assert registry.load_agent_selection(agent_dir) == []
assert registry.load_agent_selection(agent_dir) == ([], None)
@pytest.mark.parametrize(
@@ -648,9 +649,9 @@ def test_load_agent_selection_rejects_wrong_types(tmp_path: Path, field, bad_val
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "mcp_registry.json").write_text(json.dumps({field: bad_value}))
configs = registry.load_agent_selection(agent_dir)
configs, max_tools = registry.load_agent_selection(agent_dir)
# All bad fields are dropped, so resolve_for_agent gets no criteria and returns []
assert configs == []
assert configs == [] and max_tools is None
# ── run_health_check ────────────────────────────────────────────────
File diff suppressed because it is too large
+140
@@ -0,0 +1,140 @@
from __future__ import annotations
from typing import Any
from framework.runner.mcp_client import MCPTool
from framework.runner.tool_registry import ToolRegistry
def _patch_connection_manager_for_fake_stdio(monkeypatch, tool_map: dict[str, list[str]]) -> None:
"""Avoid spawning real stdio MCP processes; return in-memory clients per server name."""
class FakeMCPClient:
def __init__(self, config: Any):
self.config = config
def connect(self) -> None:
return
def disconnect(self) -> None:
return
def list_tools(self) -> list[MCPTool]:
names = tool_map.get(self.config.name, [])
return [_make_tool(n, self.config.name) for n in names]
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
raise NotImplementedError
class FakeManager:
def acquire(self, config: Any) -> FakeMCPClient:
return FakeMCPClient(config)
def release(self, _server_name: str) -> None:
return
monkeypatch.setattr(
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
lambda: FakeManager(),
)
def _make_tool(name: str, server_name: str) -> MCPTool:
return MCPTool(
name=name,
description=f"{name} from {server_name}",
input_schema={"type": "object", "properties": {}, "required": []},
server_name=server_name,
)
def test_registry_first_wins_collisions(monkeypatch):
"""
When multiple registry servers expose the same tool name, the first server
in load order should win and later servers should not overwrite it.
"""
tool_map: dict[str, list[str]] = {
"s1": ["tool_common", "tool_hive"],
"s2": ["tool_common", "tool_coder"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
log_collisions=True,
)
assert registry.has_tool("tool_common") is True
assert registry.has_tool("tool_hive") is True
assert registry.has_tool("tool_coder") is True
assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"}
assert registry.get_server_tool_names("s2") == {"tool_coder"}
def test_registry_precedence_over_existing_mcp_servers(monkeypatch):
"""Registry-loaded tools should not overwrite already registered MCP tools."""
tool_map: dict[str, list[str]] = {
"pre": ["tool_common", "tool_pre"],
"s1": ["tool_common", "tool_hive"],
"s2": ["tool_common", "tool_coder"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.register_mcp_server(
{"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None}
)
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
log_collisions=True,
)
assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"}
assert registry.get_server_tool_names("s1") == {"tool_hive"}
assert registry.get_server_tool_names("s2") == {"tool_coder"}
def test_registry_max_tools_cap(monkeypatch):
"""max_tools caps the total number of newly added tools from registry servers."""
tool_map: dict[str, list[str]] = {
"s1": ["tool_a", "tool_b"],
"s2": ["tool_c"],
}
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
resolved_servers = [
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
]
registry = ToolRegistry()
registry.load_registry_servers(
resolved_servers,
log_summary=False,
preserve_existing_tools=True,
max_tools=2,
)
assert registry.has_tool("tool_a") is True
assert registry.has_tool("tool_b") is True
assert registry.has_tool("tool_c") is False
+105 -70
@@ -1,7 +1,8 @@
"""
Tests for path traversal vulnerability fix in FileStorage.
Tests for path traversal vulnerability protection in ConcurrentStorage.
Verifies that the _validate_key() method properly blocks path traversal attempts.
Verifies that the _validate_key() method properly blocks path traversal
attempts and that the public storage API enforces these checks end-to-end.
"""
import tempfile
@@ -9,23 +10,22 @@ from pathlib import Path
import pytest
from framework.storage.backend import FileStorage
from framework.storage.concurrent import ConcurrentStorage
class TestPathTraversalProtection:
"""Tests for path traversal vulnerability protection."""
"""Tests for path traversal vulnerability protection in ConcurrentStorage."""
@pytest.fixture
def storage(self):
"""Create a temporary storage instance for testing."""
"""Create a temporary ConcurrentStorage instance for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
yield FileStorage(tmpdir)
yield ConcurrentStorage(tmpdir)
# === VALID KEYS (should pass validation) ===
def test_valid_alphanumeric_key(self, storage):
"""Alphanumeric keys should be allowed."""
# Should not raise
storage._validate_key("goal_123")
storage._validate_key("run_abc_def")
storage._validate_key("status_completed")
@@ -40,7 +40,6 @@ class TestPathTraversalProtection:
def test_blocks_parent_directory_traversal(self, storage):
"""Block .. path traversal attempts."""
# These all have path separators which are blocked first
with pytest.raises(ValueError):
storage._validate_key("../../../etc/passwd")
@@ -55,13 +54,12 @@ class TestPathTraversalProtection:
with pytest.raises(ValueError, match="path traversal detected"):
storage._validate_key(".env")
# This also has path separator which is caught first
# Also has a path separator which is caught first
with pytest.raises(ValueError):
storage._validate_key(".ssh/id_rsa")
def test_blocks_absolute_paths_unix(self, storage):
"""Block absolute paths (Unix)."""
# These have path separators which are blocked first
with pytest.raises(ValueError):
storage._validate_key("/etc/passwd")
@@ -70,7 +68,6 @@ class TestPathTraversalProtection:
def test_blocks_absolute_paths_windows(self, storage):
"""Block absolute paths (Windows)."""
# These have path separators which are blocked first
with pytest.raises(ValueError):
storage._validate_key("C:\\Windows\\System32")
@@ -115,68 +112,76 @@ class TestPathTraversalProtection:
with pytest.raises(ValueError, match="empty"):
storage._validate_key(" ")
    # === END-TO-END TESTS (public API enforces validation) ===

    @pytest.mark.asyncio
    async def test_load_run_blocks_traversal(self, storage):
        """load_run() must reject path traversal in the run_id."""
        with pytest.raises(ValueError):
            await storage.load_run("../../../.env")

    @pytest.mark.asyncio
    async def test_load_run_valid_id_returns_none(self, storage):
        """A valid but nonexistent run_id returns None, not an error."""
        result = await storage.load_run("legitimate_run_id", use_cache=False)
        assert result is None

    @pytest.mark.asyncio
    async def test_delete_run_blocks_traversal(self, storage):
        """delete_run() must reject path traversal in the run_id."""
        with pytest.raises(ValueError):
            await storage.delete_run("../etc/passwd")

    @pytest.mark.asyncio
    async def test_load_summary_blocks_traversal(self, storage):
        """load_summary() must reject path traversal in the run_id."""
        with pytest.raises(ValueError):
            await storage.load_summary("../../../.env")

    def test_load_run_sync_blocks_traversal(self, storage):
        """load_run_sync() must reject path traversal in the run_id."""
        with pytest.raises(ValueError):
            storage.load_run_sync("../../../.env")

    def test_save_run_sync_blocks_traversal(self, storage):
        """save_run_sync() must reject path traversal in the run_id."""
        from framework.schemas.run import Run

        run = Run(id="../../../.env", goal_id="test", goal_description="", input_data={})
        with pytest.raises(ValueError):
            storage.save_run_sync(run)

    def test_load_run_sync_valid_id_returns_none(self, storage):
        """load_run_sync with a legitimate nonexistent ID returns None."""
        result = storage.load_run_sync("legitimate_run_id")
        assert result is None

    # === REAL-WORLD ATTACK SCENARIOS (end-to-end) ===

    def test_blocks_env_file_escape_via_load_sync(self, storage):
        """Block attempts to read .env files via load_run_sync."""
        with pytest.raises(ValueError):
            storage.load_run_sync("../../../.env")

    def test_blocks_config_file_escape_via_load_sync(self, storage):
        """Block attempts to access config files via load_run_sync."""
        with pytest.raises(ValueError):
            storage.load_run_sync("../../../../etc/aden/database.yaml")

    def test_blocks_arbitrary_write_via_save_sync(self, storage):
        """Block attempts to write arbitrary files via save_run_sync."""
        from framework.schemas.run import Run

        run = Run(id="../../var/www/html/shell", goal_id="test", goal_description="", input_data={})
        with pytest.raises(ValueError):
            storage.save_run_sync(run)
class TestPathTraversalWithActualFiles:
"""Test path traversal protection with actual file operations."""
def test_cannot_escape_storage_directory(self):
        """Verify that path traversal is caught before any filesystem access."""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
storage_dir = tmpdir_path / "storage"
@@ -186,31 +191,61 @@ class TestPathTraversalWithActualFiles:
secret_file = tmpdir_path / "secret.txt"
secret_file.write_text("SENSITIVE_DATA", encoding="utf-8")
            storage = ConcurrentStorage(storage_dir)

            # Attempt to read the secret file via path traversal; must raise
            with pytest.raises(ValueError):
                storage.load_run_sync("../secret")

            # Verify the secret file was not accessed
assert secret_file.read_text(encoding="utf-8") == "SENSITIVE_DATA"
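The tests above exercise a key validator that rejects hostile inputs before any filesystem access. A minimal sketch of such a check (a hypothetical `validate_key` helper; the real `_validate_key` lives in the storage module and may differ):

```python
from pathlib import PurePosixPath, PureWindowsPath


def validate_key(key: str) -> str:
    """Reject keys that could escape the storage directory."""
    if not key or not key.strip():
        raise ValueError("Key must not be empty")
    # Separators are checked first, so "/etc/passwd" and "C:\\Windows\\..."
    # are rejected here before the absolute-path check even runs.
    if "/" in key or "\\" in key:
        raise ValueError(f"Key must not contain path separators: {key!r}")
    if ".." in key:
        raise ValueError(f"Key must not contain parent references: {key!r}")
    if PurePosixPath(key).is_absolute() or PureWindowsPath(key).is_absolute():
        raise ValueError(f"Key must not be an absolute path: {key!r}")
    return key
```

Validating the bare key, rather than resolving paths after joining, keeps the check cheap and platform-independent.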
    def test_save_and_load_roundtrip(self, tmp_path):
        """Verify save_run_sync/load_run_sync roundtrip works correctly."""
        from framework.schemas.run import Run, RunStatus

        storage = ConcurrentStorage(tmp_path)
        run = Run(
            id="run_test_123",
            goal_id="goal_abc",
            goal_description="Integration test",
            input_data={},
        )
        run.complete(RunStatus.COMPLETED, "done")

        storage.save_run_sync(run)

        loaded = storage.load_run_sync("run_test_123")
        assert loaded is not None
        assert loaded.id == "run_test_123"
        assert loaded.status == RunStatus.COMPLETED

        # Verify the file is at the expected path
        run_file = tmp_path / "runs" / "run_test_123.json"
        assert run_file.exists()
class TestSessionStorePathTraversal:
"""Path traversal protection in SessionStore.get_session_path()."""
@pytest.fixture
def store(self, tmp_path):
from framework.storage.session_store import SessionStore
return SessionStore(tmp_path)
def test_valid_session_id(self, store):
path = store.get_session_path("session_20260206_143022_abc12345")
assert path.name == "session_20260206_143022_abc12345"
def test_blocks_parent_traversal(self, store):
with pytest.raises(ValueError, match="Invalid session ID"):
store.get_session_path("../../etc/passwd")
@pytest.mark.asyncio
async def test_delete_session_blocks_traversal(self, store):
with pytest.raises(ValueError, match="Invalid session ID"):
await store.delete_session("../../package")
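A sketch of the allowlist-style validation the SessionStore tests assume: match the session ID against a strict pattern before joining it to the base path. The pattern below is inferred from the valid example above and is an assumption, not the framework's actual regex.

```python
import re
from pathlib import Path

# Inferred shape: session_<YYYYMMDD>_<HHMMSS>_<8 hex chars>
SESSION_ID_RE = re.compile(r"session_\d{8}_\d{6}_[0-9a-f]{8}")


def get_session_path(base_dir: Path, session_id: str) -> Path:
    """Return base_dir/session_id, rejecting anything off-pattern."""
    if not SESSION_ID_RE.fullmatch(session_id):
        raise ValueError(f"Invalid session ID: {session_id!r}")
    return base_dir / session_id
```

Because the ID must fully match the pattern, traversal strings such as `../../etc/passwd` never reach the path join at all.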
if __name__ == "__main__":
@@ -0,0 +1,35 @@
from datetime import date
from framework.agents.queen import queen_memory
from framework.tools.queen_memory_tools import recall_diary
def test_format_memory_date_uses_unpadded_day() -> None:
assert queen_memory.format_memory_date(date(2026, 3, 7)) == "March 7, 2026"
def test_format_for_injection_formats_recent_memory(monkeypatch) -> None:
monkeypatch.setattr(queen_memory, "read_semantic_memory", lambda: "")
monkeypatch.setattr(
queen_memory,
"_find_recent_episodic",
lambda lookback=7: (date(2026, 3, 7), "Remembered context."),
)
result = queen_memory.format_for_injection()
assert "## March 7, 2026" in result
assert "Remembered context." in result
def test_recall_diary_formats_today_without_platform_specific_strftime(monkeypatch) -> None:
monkeypatch.setattr(
queen_memory,
"read_episodic_memory",
lambda d=None: "Today's note." if d == date.today() else "",
)
result = recall_diary(days_back=1)
assert "## Today" in result
assert "Today's note." in result
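The first test above guards against `%-d` (glibc) and `%#d` (Windows), which are platform-specific strftime extensions for unpadded day numbers. One portable implementation consistent with those assertions (the framework's actual code may differ):

```python
from datetime import date


def format_memory_date(d: date) -> str:
    # strftime("%-d") only works on glibc and "%#d" only on Windows, so
    # build the unpadded day from the attribute instead of a format code.
    return f"{d.strftime('%B')} {d.day}, {d.year}"
```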
@@ -37,20 +37,21 @@ class TestRuntimeBasics:
runtime.end_run(success=True)
assert runtime.current_run is None
    def test_run_saved_on_end(self, tmp_path: Path):
        """Run is persisted to disk when ended.

        ConcurrentStorage.save_run_sync() writes to runs/{run_id}.json
        via an atomic temp-file+rename. This is the primary guardrail
        ensuring end_run() does not silently discard completed runs.
        """
runtime = Runtime(tmp_path)
run_id = runtime.start_run("test_goal", "Test")
runtime.end_run(success=True)
        # ConcurrentStorage writes to {base_path}/runs/{run_id}.json
        run_file = tmp_path / "runs" / f"{run_id}.json"
        assert run_file.exists(), f"Expected persisted run at {run_file}"
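The docstring above references an atomic temp-file+rename write. A minimal sketch of that pattern (a hypothetical helper, not the actual ConcurrentStorage internals): write to a temp file in the destination directory, then `os.replace()` it into place so readers never observe a partial file.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, data: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    # The temp file must live on the same filesystem as the target for
    # the rename to be atomic.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_name, path)  # atomic on POSIX and Windows
    except BaseException:
        os.unlink(tmp_name)  # don't leave a stray temp file on failure
        raise
```

A crash between `mkstemp` and `os.replace` leaves at worst an orphaned `.tmp` file, never a truncated `runs/{run_id}.json`.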
class TestDecisionRecording:
@@ -346,7 +347,7 @@ class TestNarrativeGeneration:
"""Test automatic narrative generation."""
@pytest.mark.skip(
        reason="save_run() and get_runs_by_goal() are deprecated. "
"New sessions use unified storage at sessions/{session_id}/state.json"
)
def test_default_narrative_success(self, tmp_path: Path):
@@ -369,7 +370,7 @@ class TestNarrativeGeneration:
assert "completed successfully" in run.narrative
@pytest.mark.skip(
        reason="save_run() and get_runs_by_goal() are deprecated. "
"New sessions use unified storage at sessions/{session_id}/state.json"
)
def test_default_narrative_failure(self, tmp_path: Path):
@@ -0,0 +1,579 @@
"""Integration tests for hive skill CLI command handlers.
Uses argparse.Namespace objects directly (not argv parsing) for concise tests.
"""
from __future__ import annotations
import json
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch
from framework.skills.cli import (
cmd_skill_doctor,
cmd_skill_info,
cmd_skill_init,
cmd_skill_install,
cmd_skill_list,
cmd_skill_remove,
cmd_skill_search,
cmd_skill_test,
cmd_skill_validate,
)
def _make_valid_skill(parent: Path, name: str) -> Path:
"""Create a minimal valid skill in parent/name/SKILL.md."""
d = parent / name
d.mkdir(parents=True, exist_ok=True)
(d / "SKILL.md").write_text(
f"---\nname: {name}\ndescription: A test skill.\nlicense: MIT\n---\n\n## Body\n",
encoding="utf-8",
)
return d
class TestCmdSkillInit:
def test_creates_skill_md(self, tmp_path):
args = Namespace(skill_name="test-skill", target_dir=str(tmp_path))
result = cmd_skill_init(args)
assert result == 0
assert (tmp_path / "test-skill" / "SKILL.md").exists()
def test_skill_md_contains_name(self, tmp_path):
args = Namespace(skill_name="my-skill", target_dir=str(tmp_path))
cmd_skill_init(args)
content = (tmp_path / "my-skill" / "SKILL.md").read_text()
assert "name: my-skill" in content
def test_error_when_dir_exists(self, tmp_path, capsys):
(tmp_path / "existing").mkdir()
args = Namespace(skill_name="existing", target_dir=str(tmp_path))
result = cmd_skill_init(args)
assert result == 1
assert "already exists" in capsys.readouterr().err
def test_error_when_no_name(self, tmp_path, monkeypatch, capsys):
# Non-interactive (stdin not a tty in test env) → error
monkeypatch.setattr("sys.stdin.isatty", lambda: False)
args = Namespace(skill_name=None, target_dir=str(tmp_path))
result = cmd_skill_init(args)
assert result == 1
class TestCmdSkillValidate:
def test_exits_0_on_valid_skill(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
args = Namespace(path=str(skill_dir / "SKILL.md"))
result = cmd_skill_validate(args)
assert result == 0
def test_accepts_directory_path(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
args = Namespace(path=str(skill_dir))
result = cmd_skill_validate(args)
assert result == 0
def test_exits_1_on_invalid_skill(self, tmp_path, capsys):
skill_dir = tmp_path / "bad-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("no frontmatter here", encoding="utf-8")
args = Namespace(path=str(skill_dir / "SKILL.md"))
result = cmd_skill_validate(args)
assert result == 1
assert "[ERROR]" in capsys.readouterr().out
def test_shows_warnings_on_valid_skill_without_license(self, tmp_path, capsys):
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\nname: my-skill\ndescription: No license.\n---\n\n## Body\n",
encoding="utf-8",
)
args = Namespace(path=str(skill_dir / "SKILL.md"))
result = cmd_skill_validate(args)
assert result == 0
assert "[WARN]" in capsys.readouterr().out
class TestCmdSkillDoctor:
def test_defaults_pass_against_real_framework_skills(self):
"""All 6 framework default skills should be healthy (no mocking)."""
args = Namespace(defaults=True, name=None, project_dir=None)
result = cmd_skill_doctor(args)
assert result == 0
def test_named_skill_not_found_exits_1(self, tmp_path, capsys):
args = Namespace(name="nonexistent-skill", defaults=False, project_dir=str(tmp_path))
result = cmd_skill_doctor(args)
assert result == 1
assert "not found" in capsys.readouterr().err
def test_healthy_skill_exits_0(self, tmp_path):
_make_valid_skill(tmp_path, "my-skill")
args = Namespace(name=None, defaults=False, project_dir=str(tmp_path))
with patch("framework.skills.discovery.SkillDiscovery.discover") as mock_discover:
from framework.skills.parser import ParsedSkill
mock_discover.return_value = [
ParsedSkill(
name="my-skill",
description="Test.",
location=str(tmp_path / "my-skill" / "SKILL.md"),
base_dir=str(tmp_path / "my-skill"),
source_scope="user",
body="## Body",
)
]
result = cmd_skill_doctor(args)
assert result == 0
class TestCmdSkillInstall:
def test_shows_security_notice_on_first_use(self, tmp_path, monkeypatch, capsys):
sentinel = tmp_path / ".install_notice_shown"
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
installed_path = tmp_path / "skills" / "my-skill"
installed_path.mkdir(parents=True)
args = Namespace(
name_or_url=None,
from_url="https://example.com/skill.git",
pack=None,
install_name="my-skill",
version=None,
)
with patch("framework.skills.installer.install_from_git", return_value=installed_path):
with patch("shutil.which", return_value="/usr/bin/git"):
result = cmd_skill_install(args)
captured = capsys.readouterr()
assert "Security Notice" in captured.out
assert result == 0
def test_install_from_url_calls_install_from_git(self, tmp_path, monkeypatch):
sentinel = tmp_path / ".install_notice_shown"
sentinel.parent.mkdir(parents=True, exist_ok=True)
sentinel.touch()
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
installed_path = tmp_path / "skills" / "my-skill"
installed_path.mkdir(parents=True)
args = Namespace(
name_or_url=None,
from_url="https://github.com/org/my-skill.git",
pack=None,
install_name=None,
version=None,
)
with patch(
"framework.skills.installer.install_from_git", return_value=installed_path
) as mock_install:
result = cmd_skill_install(args)
mock_install.assert_called_once()
assert result == 0
def test_registry_not_found_exits_1(self, tmp_path, monkeypatch, capsys):
sentinel = tmp_path / ".install_notice_shown"
sentinel.parent.mkdir(parents=True, exist_ok=True)
sentinel.touch()
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
args = Namespace(
name_or_url="nonexistent-skill",
from_url=None,
pack=None,
install_name=None,
version=None,
)
with patch("framework.skills.registry.RegistryClient.get_skill_entry", return_value=None):
result = cmd_skill_install(args)
assert result == 1
assert "not found in registry" in capsys.readouterr().err
def test_no_args_exits_1(self, tmp_path, monkeypatch, capsys):
sentinel = tmp_path / ".install_notice_shown"
sentinel.parent.mkdir(parents=True, exist_ok=True)
sentinel.touch()
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
args = Namespace(
name_or_url=None, from_url=None, pack=None, install_name=None, version=None
)
result = cmd_skill_install(args)
assert result == 1
class TestCmdSkillRemove:
def test_removes_installed_skill(self, tmp_path, capsys):
skills_dir = tmp_path / "skills"
skill_dir = skills_dir / "my-skill"
skill_dir.mkdir(parents=True)
with patch("framework.skills.installer.USER_SKILLS_DIR", skills_dir):
with patch("framework.skills.installer.remove_skill", return_value=True):
args = Namespace(name="my-skill")
result = cmd_skill_remove(args)
assert result == 0
assert "Removed" in capsys.readouterr().out
def test_exits_1_when_not_found(self, tmp_path, capsys):
with patch("framework.skills.installer.remove_skill", return_value=False):
args = Namespace(name="missing-skill")
result = cmd_skill_remove(args)
assert result == 1
assert "not found" in capsys.readouterr().err
class TestCmdSkillSearch:
def test_exits_1_when_registry_unavailable(self, capsys):
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=None):
args = Namespace(query="research")
result = cmd_skill_search(args)
assert result == 1
assert "registry unavailable" in capsys.readouterr().err.lower()
def test_prints_results_when_found(self, capsys):
mock_index = {
"skills": [
{
"name": "deep-research",
"description": "Multi-step research.",
"tags": ["research"],
"trust_tier": "official",
}
]
}
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=mock_index):
args = Namespace(query="research")
result = cmd_skill_search(args)
out = capsys.readouterr().out
assert result == 0
assert "deep-research" in out
def test_no_results_message(self, capsys):
mock_index = {"skills": []}
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=mock_index):
args = Namespace(query="xyzzy-nothing")
result = cmd_skill_search(args)
assert result == 0
assert "No skills found" in capsys.readouterr().out
class TestCmdSkillInfo:
def test_shows_locally_installed_skill(self, tmp_path, capsys):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
from framework.skills.parser import ParsedSkill
mock_skill = ParsedSkill(
name="my-skill",
description="A test skill.",
location=str(skill_dir / "SKILL.md"),
base_dir=str(skill_dir),
source_scope="user",
body="## Body",
license="MIT",
)
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[mock_skill]):
args = Namespace(name="my-skill", project_dir=str(tmp_path))
result = cmd_skill_info(args)
out = capsys.readouterr().out
assert result == 0
assert "my-skill" in out
assert "A test skill." in out
def test_falls_back_to_registry_when_not_installed(self, capsys):
registry_entry = {
"name": "deep-research",
"description": "Multi-step research.",
"version": "1.0.0",
"author": "anthropics",
"trust_tier": "official",
}
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
with patch(
"framework.skills.registry.RegistryClient.get_skill_entry",
return_value=registry_entry,
):
args = Namespace(name="deep-research", project_dir=None)
result = cmd_skill_info(args)
out = capsys.readouterr().out
assert result == 0
assert "not installed" in out
assert "deep-research" in out
def test_exits_1_when_not_found_anywhere(self, tmp_path, capsys):
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
with patch(
"framework.skills.registry.RegistryClient.get_skill_entry", return_value=None
):
args = Namespace(name="ghost-skill", project_dir=str(tmp_path))
result = cmd_skill_info(args)
assert result == 1
class TestJsonFlag:
def test_list_json_produces_valid_json(self, tmp_path, capsys):
args = Namespace(project_dir=str(tmp_path), json=True)
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
result = cmd_skill_list(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 0
assert "skills" in data
assert isinstance(data["skills"], list)
def test_validate_json_valid_skill(self, tmp_path, capsys):
from framework.skills.cli import cmd_skill_validate
skill_dir = _make_valid_skill(tmp_path, "my-skill")
args = Namespace(path=str(skill_dir / "SKILL.md"), json=True)
result = cmd_skill_validate(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 0
assert data["passed"] is True
assert data["errors"] == []
assert "warnings" in data
def test_doctor_defaults_json(self, capsys):
args = Namespace(defaults=True, name=None, project_dir=None, json=True)
result = cmd_skill_doctor(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 0
assert "skills" in data
assert len(data["skills"]) == 6 # 6 framework default skills
assert data["total_errors"] == 0
def test_search_json_registry_unavailable_exits_1(self, capsys):
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=None):
args = Namespace(query="research", json=True)
result = cmd_skill_search(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 1
assert "error" in data
def test_remove_json_not_found_exits_1(self, capsys):
with patch("framework.skills.installer.remove_skill", return_value=False):
args = Namespace(name="ghost-skill", json=True)
result = cmd_skill_remove(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 1
assert "error" in data
class TestCmdSkillTest:
"""Tests for hive skill test (CLI-9)."""
def test_structural_only_valid_exits_0(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
result = cmd_skill_test(args)
assert result == 0
def test_structural_invalid_exits_1(self, tmp_path, capsys):
skill_dir = tmp_path / "bad-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("no frontmatter", encoding="utf-8")
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
result = cmd_skill_test(args)
assert result == 1
assert "[ERROR]" in capsys.readouterr().out
def test_invocation_mode_calls_provider_with_skill_body(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
from unittest.mock import MagicMock
from framework.llm.provider import LLMResponse
mock_response = LLMResponse(content="Hello!", model="claude-haiku-4-5-20251001")
mock_provider = MagicMock()
mock_provider.complete.return_value = mock_response
args = Namespace(
path=str(skill_dir), input_json='{"prompt": "say hello"}', model=None, json=False
)
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
result = cmd_skill_test(args)
assert result == 0
call_kwargs = mock_provider.complete.call_args
assert call_kwargs is not None
# system should be the skill body
assert "system" in call_kwargs.kwargs or len(call_kwargs.args) >= 2
def test_invocation_extracts_prompt_from_json(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
from unittest.mock import MagicMock
from framework.llm.provider import LLMResponse
mock_provider = MagicMock()
mock_provider.complete.return_value = LLMResponse(
content="response", model="claude-haiku-4-5-20251001"
)
args = Namespace(
path=str(skill_dir), input_json='{"prompt": "extracted prompt"}', model=None, json=False
)
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
cmd_skill_test(args)
call = mock_provider.complete.call_args
messages = call.kwargs.get("messages") or (call.args[0] if call.args else [])
assert any("extracted prompt" in m.get("content", "") for m in messages)
def test_eval_suite_all_pass_exits_0(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
evals_dir = skill_dir / "evals"
evals_dir.mkdir()
(evals_dir / "evals.json").write_text(
json.dumps(
{
"skill_name": "my-skill",
"evals": [
{"id": 1, "prompt": "Say hi.", "assertions": ["Response is a greeting"]}
],
}
),
encoding="utf-8",
)
from unittest.mock import MagicMock
from framework.llm.provider import LLMResponse
mock_provider = MagicMock()
mock_provider.complete.return_value = LLMResponse(
content="Hello!", model="claude-haiku-4-5-20251001"
)
mock_judge = MagicMock()
mock_judge.evaluate.return_value = {"passes": True, "explanation": "Looks good."}
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
result = cmd_skill_test(args)
assert result == 0
def test_eval_any_fail_exits_1(self, tmp_path):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
evals_dir = skill_dir / "evals"
evals_dir.mkdir()
(evals_dir / "evals.json").write_text(
json.dumps(
{
"skill_name": "my-skill",
"evals": [
{"id": 1, "prompt": "Say hi.", "assertions": ["Impossible assertion"]}
],
}
),
encoding="utf-8",
)
from unittest.mock import MagicMock
from framework.llm.provider import LLMResponse
mock_provider = MagicMock()
mock_provider.complete.return_value = LLMResponse(
content="Hello!", model="claude-haiku-4-5-20251001"
)
mock_judge = MagicMock()
mock_judge.evaluate.return_value = {"passes": False, "explanation": "Did not satisfy."}
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
result = cmd_skill_test(args)
assert result == 1
def test_json_flag_structural_output(self, tmp_path, capsys):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=True)
result = cmd_skill_test(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 0
assert "structural" in data
assert data["structural"]["passed"] is True
assert data["skill"] == "my-skill"
def test_json_flag_eval_results(self, tmp_path, capsys):
skill_dir = _make_valid_skill(tmp_path, "my-skill")
evals_dir = skill_dir / "evals"
evals_dir.mkdir()
(evals_dir / "evals.json").write_text(
json.dumps(
{
"skill_name": "my-skill",
"evals": [{"id": 1, "prompt": "Hi.", "assertions": ["Is a greeting"]}],
}
),
encoding="utf-8",
)
from unittest.mock import MagicMock
from framework.llm.provider import LLMResponse
mock_provider = MagicMock()
mock_provider.complete.return_value = LLMResponse(
content="Hello!", model="claude-haiku-4-5-20251001"
)
mock_judge = MagicMock()
mock_judge.evaluate.return_value = {"passes": True, "explanation": "Yes."}
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=True)
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
result = cmd_skill_test(args)
out = capsys.readouterr().out
data = json.loads(out)
assert result == 0
assert "evals" in data
assert data["total_passed"] == 1
assert data["total_failed"] == 0
def test_no_api_key_with_evals_degrades_gracefully(self, tmp_path, capsys):
"""No API key + evals present → structural checks pass, skip LLM, exit 0."""
skill_dir = _make_valid_skill(tmp_path, "my-skill")
(skill_dir / "evals").mkdir()
(skill_dir / "evals" / "evals.json").write_text(
json.dumps({"skill_name": "my-skill", "evals": []}), encoding="utf-8"
)
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
with patch(
"framework.llm.anthropic.AnthropicProvider",
side_effect=ValueError("ANTHROPIC_API_KEY not set"),
):
result = cmd_skill_test(args)
assert result == 0
assert "ANTHROPIC_API_KEY" in capsys.readouterr().err
@@ -0,0 +1,248 @@
"""Tests for skill install, remove, and fork operations."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
from framework.skills.installer import (
fork_skill,
install_from_git,
maybe_show_install_notice,
remove_skill,
)
from framework.skills.parser import ParsedSkill
from framework.skills.skill_errors import SkillError
def _make_skill_dir(parent: Path, name: str, body: str = "## Instructions\n\nDo things.") -> Path:
"""Create a minimal skill directory with a valid SKILL.md."""
skill_dir = parent / name
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
f"---\nname: {name}\ndescription: A test skill.\n---\n\n{body}\n",
encoding="utf-8",
)
return skill_dir
def _make_parsed_skill(base_dir: Path, name: str) -> ParsedSkill:
"""Create a ParsedSkill pointing to base_dir."""
return ParsedSkill(
name=name,
description="Test skill.",
location=str(base_dir / "SKILL.md"),
base_dir=str(base_dir),
source_scope="user",
body="## Instructions",
)
class TestInstallFromGit:
def test_copies_skill_dir_to_target(self, tmp_path):
"""Successful clone copies skill directory to target."""
source_repo = tmp_path / "repo"
_make_skill_dir(source_repo, ".") # SKILL.md at repo root
target = tmp_path / "skills"
def fake_clone(git_url, target_path, version=None):
# Simulate git clone by copying source_repo into target_path
import shutil
if target_path.exists():
shutil.rmtree(target_path)
shutil.copytree(source_repo, target_path)
with patch("framework.skills.installer._git_clone_shallow", side_effect=fake_clone):
with patch("shutil.which", return_value="/usr/bin/git"):
dest = install_from_git(
git_url="https://example.com/skill.git",
skill_name="my-skill",
target_dir=target,
)
assert (dest / "SKILL.md").exists()
assert dest == target / "my-skill"
def test_raises_when_git_not_found(self, tmp_path):
with patch("shutil.which", return_value=None):
with pytest.raises(SkillError) as exc_info:
install_from_git(
git_url="https://example.com/skill.git",
skill_name="my-skill",
target_dir=tmp_path / "skills",
)
assert "git is not installed" in exc_info.value.why
def test_raises_when_skill_md_missing(self, tmp_path):
"""Clone succeeds but no SKILL.md in the subdirectory → error."""
empty_repo = tmp_path / "empty_repo"
empty_repo.mkdir()
def fake_clone(git_url, target_path, version=None):
import shutil
if target_path.exists():
shutil.rmtree(target_path)
shutil.copytree(empty_repo, target_path)
with patch("framework.skills.installer._git_clone_shallow", side_effect=fake_clone):
with patch("shutil.which", return_value="/usr/bin/git"):
with pytest.raises(SkillError) as exc_info:
install_from_git(
git_url="https://example.com/skill.git",
skill_name="my-skill",
subdirectory="deep-research",
target_dir=tmp_path / "skills",
)
assert exc_info.value.code.value == "SKILL_NOT_FOUND"
def test_raises_when_target_already_exists(self, tmp_path):
skills_dir = tmp_path / "skills"
(skills_dir / "existing-skill").mkdir(parents=True)
with patch("shutil.which", return_value="/usr/bin/git"):
with pytest.raises(SkillError) as exc_info:
install_from_git(
git_url="https://example.com/skill.git",
skill_name="existing-skill",
target_dir=skills_dir,
)
assert "already exists" in exc_info.value.why
def test_cleans_temp_dir_on_clone_failure(self, tmp_path):
"""Temporary directory is cleaned up even when clone fails."""
created_tmp_dirs = []
original_mkdtemp = __import__("tempfile").mkdtemp
def tracking_mkdtemp(**kwargs):
d = original_mkdtemp(**kwargs)
created_tmp_dirs.append(d)
return d
def failing_clone(git_url, target_path, version=None):
from framework.skills.skill_errors import SkillErrorCode as SEC
raise SkillError(
code=SEC.SKILL_ACTIVATION_FAILED,
what="clone failed",
why="network error",
fix="check network",
)
with patch("tempfile.mkdtemp", side_effect=tracking_mkdtemp):
with patch("framework.skills.installer._git_clone_shallow", side_effect=failing_clone):
with patch("shutil.which", return_value="/usr/bin/git"):
with pytest.raises(SkillError):
install_from_git(
git_url="https://example.com/skill.git",
skill_name="my-skill",
target_dir=tmp_path / "skills",
)
# All created temp dirs should be cleaned up
for d in created_tmp_dirs:
assert not Path(d).exists(), f"Temp dir not cleaned: {d}"
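The cleanup guarantee this test verifies can be sketched as a try/finally around the temporary clone directory (a hypothetical shape of install_from_git's internals, not the actual code):

```python
import shutil
import tempfile
from pathlib import Path


def clone_into_temp_then_install(clone_fn, dest: Path) -> Path:
    """Clone into a throwaway dir, copy to dest, always clean the temp dir."""
    tmp = Path(tempfile.mkdtemp(prefix="skill-install-"))
    try:
        clone_fn(tmp)  # may raise on network/git errors
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(tmp, dest)
        return dest
    finally:
        # Runs on both success and failure, so the temp dir never leaks.
        shutil.rmtree(tmp, ignore_errors=True)
```

Because the `finally` block runs whether or not `clone_fn` raises, the tracking assertion in the test above holds for every created temp dir.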
class TestRemoveSkill:
def test_removes_existing_skill(self, tmp_path):
skills_dir = tmp_path / "skills"
skill_dir = _make_skill_dir(skills_dir, "my-skill")
assert skill_dir.exists()
result = remove_skill("my-skill", skills_dir=skills_dir)
assert result is True
assert not skill_dir.exists()
def test_returns_false_when_not_found(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
result = remove_skill("nonexistent", skills_dir=skills_dir)
assert result is False
def test_raises_on_permission_error(self, tmp_path):
skills_dir = tmp_path / "skills"
_make_skill_dir(skills_dir, "locked-skill")
with patch("shutil.rmtree", side_effect=OSError("permission denied")):
with pytest.raises(SkillError) as exc_info:
remove_skill("locked-skill", skills_dir=skills_dir)
assert "permission" in exc_info.value.why.lower()
class TestForkSkill:
def test_copies_skill_to_new_name(self, tmp_path):
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
source = _make_parsed_skill(source_dir, "my-skill")
target_parent = tmp_path / "skills"
dest = fork_skill(source, "my-skill-fork", target_parent)
assert dest.exists()
assert (dest / "SKILL.md").exists()
def test_rewrites_name_in_skill_md(self, tmp_path):
source_dir = _make_skill_dir(tmp_path / "sources", "original")
source = _make_parsed_skill(source_dir, "original")
target_parent = tmp_path / "skills"
dest = fork_skill(source, "forked", target_parent)
import yaml
content = (dest / "SKILL.md").read_text(encoding="utf-8")
parts = content.split("---", 2)
fm = yaml.safe_load(parts[1])
assert fm["name"] == "forked"
def test_raises_when_dest_already_exists(self, tmp_path):
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
source = _make_parsed_skill(source_dir, "my-skill")
target_parent = tmp_path / "skills"
(target_parent / "my-skill-fork").mkdir(parents=True)
with pytest.raises(SkillError) as exc_info:
fork_skill(source, "my-skill-fork", target_parent)
assert "already exists" in exc_info.value.why
def test_preserves_scripts_and_references(self, tmp_path):
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
(source_dir / "scripts").mkdir()
(source_dir / "scripts" / "run.sh").write_text("#!/bin/sh\necho hi")
(source_dir / "references").mkdir()
(source_dir / "references" / "guide.md").write_text("# Guide")
source = _make_parsed_skill(source_dir, "my-skill")
target_parent = tmp_path / "skills"
dest = fork_skill(source, "fork", target_parent)
assert (dest / "scripts" / "run.sh").exists()
assert (dest / "references" / "guide.md").exists()
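The fork tests above imply that `fork_skill` copies the whole skill directory (scripts, references included) and rewrites the `name` field in the SKILL.md frontmatter. A minimal sketch of that behavior under those assumptions — the helper name and frontmatter handling here are illustrative, not the framework's actual implementation:

```python
import shutil
from pathlib import Path

import yaml  # PyYAML


def fork_skill_dir(source_dir: Path, new_name: str, target_parent: Path) -> Path:
    """Copy a skill directory to a new name and rewrite its frontmatter name."""
    dest = target_parent / new_name
    if dest.exists():
        raise FileExistsError(f"{dest} already exists")
    shutil.copytree(source_dir, dest)  # preserves scripts/, references/, etc.
    skill_md = dest / "SKILL.md"
    parts = skill_md.read_text(encoding="utf-8").split("---", 2)
    fm = yaml.safe_load(parts[1]) or {}
    fm["name"] = new_name
    parts[1] = "\n" + yaml.safe_dump(fm, sort_keys=False)
    skill_md.write_text("---".join(parts), encoding="utf-8")
    return dest
```

This mirrors what `test_rewrites_name_in_skill_md` asserts: after the copy, re-parsing the frontmatter yields the forked name while the body and sibling files are untouched.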
class TestInstallNotice:
def test_shown_on_first_call(self, tmp_path, monkeypatch, capsys):
sentinel = tmp_path / ".install_notice_shown"
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
maybe_show_install_notice()
captured = capsys.readouterr()
assert "Security Notice" in captured.out
assert sentinel.exists()
def test_not_shown_on_second_call(self, tmp_path, monkeypatch, capsys):
sentinel = tmp_path / ".install_notice_shown"
sentinel.parent.mkdir(parents=True, exist_ok=True)
sentinel.touch()
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
maybe_show_install_notice()
captured = capsys.readouterr()
assert "Security Notice" not in captured.out
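The temp-dir cleanup test earlier in this file assumes `install_from_git` removes its temporary clone directory even when the clone fails. A minimal sketch of that try/finally pattern — the function and parameter names here are hypothetical, not the installer's actual API:

```python
import shutil
import tempfile
from pathlib import Path


def install_with_cleanup(clone_fn, git_url: str, target_dir: Path) -> Path:
    """Clone into a temp dir, move it into place, and always remove the temp dir."""
    tmp = Path(tempfile.mkdtemp(prefix="skill-install-"))
    try:
        clone_fn(git_url, tmp)            # may raise on network failure
        target_dir.mkdir(parents=True, exist_ok=True)
        dest = target_dir / "skill"
        shutil.move(str(tmp), str(dest))  # success: the temp dir is consumed by the move
        return dest
    finally:
        # On failure the temp dir still exists; remove it so nothing leaks.
        shutil.rmtree(tmp, ignore_errors=True)  # no-op if the move succeeded
```

The `finally` block is what the tracking test verifies: every directory handed out by `tempfile.mkdtemp` is gone after a failed install.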
@@ -0,0 +1,244 @@
"""Tests for the RegistryClient skill registry client."""
from __future__ import annotations
import json
from datetime import UTC, datetime, timedelta
from pathlib import Path
from unittest.mock import patch
from urllib.error import URLError
import pytest
from framework.skills.registry import _CACHE_TTL_SECONDS, RegistryClient
_SAMPLE_INDEX = {
"version": 1,
"skills": [
{
"name": "deep-research",
"description": "Multi-step web research with source verification.",
"version": "1.0.0",
"author": "anthropics",
"license": "MIT",
"tags": ["research", "web"],
"git_url": "https://github.com/anthropics/skills",
"subdirectory": "deep-research",
"trust_tier": "official",
},
{
"name": "code-review",
"description": "Automated code review for style and correctness.",
"version": "0.9.0",
"author": "contributor",
"tags": ["code", "review"],
"git_url": "https://github.com/contributor/code-review",
"subdirectory": None,
"trust_tier": "community",
},
],
"packs": [
{
"name": "research-starter",
"description": "Research-focused skill bundle",
"skills": ["deep-research"],
}
],
}
@pytest.fixture
def cache_dir(tmp_path):
return tmp_path / "registry_cache"
@pytest.fixture
def client(cache_dir):
return RegistryClient(registry_url="https://example.com/skill_index.json", cache_dir=cache_dir)
class TestFetchIndex:
def test_returns_none_on_network_error(self, client):
with patch.object(client, "_http_fetch", return_value=None):
result = client.fetch_index()
assert result is None
def test_returns_none_on_url_error(self, client):
with patch("framework.skills.registry.urlopen", side_effect=URLError("connection refused")):
result = client.fetch_index()
assert result is None
def test_fetches_and_caches_index(self, client):
raw = json.dumps(_SAMPLE_INDEX).encode()
with patch.object(client, "_http_fetch", return_value=raw):
result = client.fetch_index()
assert result is not None
assert len(result["skills"]) == 2
# Cache should be written
assert client._index_path.exists()
def test_uses_fresh_cache_without_network(self, client, cache_dir):
# Write fresh cache
cache_dir.mkdir(parents=True, exist_ok=True)
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
(cache_dir / "metadata.json").write_text(json.dumps(meta))
fetch_called = []
def _no_fetch(*a, **kw):
fetch_called.append(1)
with patch.object(client, "_http_fetch", side_effect=_no_fetch):
result = client.fetch_index()
assert not fetch_called, "Should not hit network when cache is fresh"
assert result is not None
def test_refreshes_when_cache_is_stale(self, client, cache_dir):
# Write stale cache (older than TTL)
cache_dir.mkdir(parents=True, exist_ok=True)
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
old_time = (datetime.now(tz=UTC) - timedelta(seconds=_CACHE_TTL_SECONDS + 60)).isoformat()
meta = {"last_fetched": old_time}
(cache_dir / "metadata.json").write_text(json.dumps(meta))
raw = json.dumps(_SAMPLE_INDEX).encode()
with patch.object(client, "_http_fetch", return_value=raw) as mock_fetch:
client.fetch_index()
mock_fetch.assert_called_once()
def test_force_refresh_bypasses_fresh_cache(self, client, cache_dir):
cache_dir.mkdir(parents=True, exist_ok=True)
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
(cache_dir / "metadata.json").write_text(json.dumps(meta))
raw = json.dumps(_SAMPLE_INDEX).encode()
with patch.object(client, "_http_fetch", return_value=raw) as mock_fetch:
client.fetch_index(force_refresh=True)
mock_fetch.assert_called_once()
def test_falls_back_to_stale_cache_on_network_error(self, client, cache_dir):
cache_dir.mkdir(parents=True, exist_ok=True)
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
# No metadata → stale
with patch.object(client, "_http_fetch", return_value=None):
result = client.fetch_index()
assert result is not None
assert result["version"] == 1
class TestSearch:
def test_filters_by_name(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
results = client.search("deep")
assert len(results) == 1
assert results[0]["name"] == "deep-research"
def test_filters_by_description(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
results = client.search("source verification")
assert any(r["name"] == "deep-research" for r in results)
def test_filters_by_tag(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
results = client.search("review")
assert any(r["name"] == "code-review" for r in results)
def test_case_insensitive(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
results = client.search("DEEP")
assert len(results) == 1
def test_returns_empty_when_unavailable(self, client):
with patch.object(client, "fetch_index", return_value=None):
results = client.search("anything")
assert results == []
def test_returns_empty_on_no_match(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
results = client.search("xyzzy-no-match")
assert results == []
class TestGetSkillEntry:
def test_finds_by_exact_name(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
entry = client.get_skill_entry("deep-research")
assert entry is not None
assert entry["name"] == "deep-research"
def test_returns_none_when_not_found(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
entry = client.get_skill_entry("nonexistent")
assert entry is None
def test_returns_none_when_index_unavailable(self, client):
with patch.object(client, "fetch_index", return_value=None):
entry = client.get_skill_entry("deep-research")
assert entry is None
class TestGetPack:
def test_returns_skill_names(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
skills = client.get_pack("research-starter")
assert skills == ["deep-research"]
def test_returns_none_when_pack_not_found(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
result = client.get_pack("nonexistent-pack")
assert result is None
def test_returns_none_when_index_unavailable(self, client):
with patch.object(client, "fetch_index", return_value=None):
result = client.get_pack("research-starter")
assert result is None
class TestResolveGitUrl:
def test_returns_git_url_and_subdirectory(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
result = client.resolve_git_url("deep-research")
assert result == ("https://github.com/anthropics/skills", "deep-research")
def test_returns_none_subdirectory_when_absent(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
result = client.resolve_git_url("code-review")
git_url, subdir = result
assert subdir is None
def test_returns_none_when_not_in_registry(self, client):
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
result = client.resolve_git_url("not-there")
assert result is None
class TestCacheAtomicWrite:
def test_atomic_write_uses_tmp_then_replace(self, client, cache_dir, monkeypatch):
written_paths = []
original_write = Path.write_text
def tracking_write(self, data, encoding=None, **kwargs):
written_paths.append(str(self))
return original_write(self, data, encoding=encoding, **kwargs)
monkeypatch.setattr(Path, "write_text", tracking_write)
client._save_cache(_SAMPLE_INDEX)
# .tmp file should have been written (then replaced — may not exist now)
assert any(".tmp" in p for p in written_paths)
# Final index file should exist
assert client._index_path.exists()
def test_save_and_load_round_trip(self, client):
client._save_cache(_SAMPLE_INDEX)
loaded = client._load_cache()
assert loaded == _SAMPLE_INDEX
def test_load_returns_none_when_absent(self, client):
result = client._load_cache()
assert result is None
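The atomic-write tests rely on the cache being written to a sibling `.tmp` file and then renamed over the final path, so a concurrent reader never sees a half-written index. A minimal sketch of that pattern — the paths and function name are illustrative, not the actual `RegistryClient._save_cache` internals:

```python
import json
import os
from pathlib import Path


def save_cache_atomic(index: dict, index_path: Path) -> None:
    """Write JSON to a sibling .tmp file, then atomically swap it into place."""
    index_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = index_path.with_suffix(index_path.suffix + ".tmp")
    tmp_path.write_text(json.dumps(index), encoding="utf-8")
    os.replace(tmp_path, index_path)  # atomic rename; readers see old or new, never partial
```

`os.replace` is the key call: unlike a plain `write_text` on the final path, an interrupted write leaves only an orphaned `.tmp` file, which is what the tracking test's "may not exist now" comment alludes to.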
@@ -0,0 +1,401 @@
"""Tests for strict SKILL.md validation (hive skill validate).
One happy-path test per strict check, plus a test for each individual failure mode.
"""
from __future__ import annotations
from pathlib import Path
from framework.skills.validator import validate_strict
def _write_skill(tmp_path: Path, content: str, dir_name: str = "my-skill") -> Path:
"""Write a SKILL.md in a named subdirectory and return the path."""
skill_dir = tmp_path / dir_name
skill_dir.mkdir(parents=True, exist_ok=True)
skill_md = skill_dir / "SKILL.md"
skill_md.write_text(content, encoding="utf-8")
return skill_md
_VALID_CONTENT = """\
---
name: my-skill
description: A test skill for validation.
version: 0.1.0
license: MIT
compatibility:
- claude-code
- hive
metadata:
tags: []
---
## Instructions
Do the thing properly.
"""
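`validate_strict` is exercised against documents shaped like `_VALID_CONTENT`: YAML frontmatter between `---` delimiters, followed by a Markdown body. A minimal sketch of how such a file splits into frontmatter and body — an assumption about the parsing step, not the validator's actual code:

```python
import yaml  # PyYAML


def split_skill_md(content: str) -> tuple[dict, str]:
    """Split a SKILL.md into (frontmatter dict, markdown body)."""
    parts = content.split("---", 2)
    if len(parts) < 3:
        raise ValueError("missing frontmatter delimiters")
    frontmatter = yaml.safe_load(parts[1]) or {}
    body = parts[2].strip()
    return frontmatter, body
```

This is the same `split("---", 2)` convention the fork tests use to re-read frontmatter, so the checks below (description present, name/directory match, body not empty) all operate on these two pieces.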
class TestHappyPath:
def test_valid_skill_passes(self, tmp_path):
path = _write_skill(tmp_path, _VALID_CONTENT)
result = validate_strict(path)
assert result.passed is True
assert result.errors == []
def test_namespace_prefix_name_allowed(self, tmp_path):
"""hive.my-skill with directory my-skill is valid."""
content = """\
---
name: hive.my-skill
description: A namespaced skill.
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content, dir_name="my-skill")
result = validate_strict(path)
assert result.passed is True
def test_warning_on_missing_license(self, tmp_path):
content = """\
---
name: my-skill
description: No license field.
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is True
assert any("license" in w.lower() for w in result.warnings)
class TestCheck1FileExists:
def test_error_on_missing_file(self, tmp_path):
path = tmp_path / "nonexistent" / "SKILL.md"
result = validate_strict(path)
assert result.passed is False
assert any("not found" in e.lower() for e in result.errors)
class TestCheck2FileNotEmpty:
def test_error_on_empty_file(self, tmp_path):
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
path = skill_dir / "SKILL.md"
path.write_text(" \n", encoding="utf-8")
result = validate_strict(path)
assert result.passed is False
assert any("empty" in e.lower() for e in result.errors)
class TestCheck3FrontmatterPresent:
def test_error_on_missing_delimiters(self, tmp_path):
path = _write_skill(tmp_path, "name: my-skill\ndescription: no delimiters\n")
result = validate_strict(path)
assert result.passed is False
assert any("frontmatter" in e.lower() or "---" in e for e in result.errors)
class TestCheck4YamlNoFixup:
def test_error_on_yaml_requiring_fixup(self, tmp_path):
"""Unquoted colon in value — lenient parser accepts, strict rejects."""
content = """\
---
name: my-skill
description: Use for: research tasks
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("YAML" in e or "parse" in e.lower() for e in result.errors)
def test_quoted_colon_passes(self, tmp_path):
content = """\
---
name: my-skill
description: "Use for: research tasks"
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is True
class TestCheck5Description:
def test_error_on_missing_description(self, tmp_path):
content = """\
---
name: my-skill
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("description" in e.lower() for e in result.errors)
def test_error_on_empty_description(self, tmp_path):
content = """\
---
name: my-skill
description: ""
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
class TestCheck6NamePresent:
def test_error_on_missing_name(self, tmp_path):
content = """\
---
description: A skill without a name.
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("name" in e.lower() for e in result.errors)
class TestCheck7NameLength:
def test_error_on_name_too_long(self, tmp_path):
long_name = "a" * 65
skill_dir = tmp_path / long_name
skill_dir.mkdir(parents=True)
content = f"---\nname: {long_name}\ndescription: Too long.\nlicense: MIT\n---\n\n## Body\n"
path = skill_dir / "SKILL.md"
path.write_text(content, encoding="utf-8")
result = validate_strict(path)
assert result.passed is False
assert any("64" in e or "characters" in e.lower() for e in result.errors)
def test_exactly_64_chars_passes(self, tmp_path):
name = "a" * 64
skill_dir = tmp_path / name
skill_dir.mkdir(parents=True)
content = f"---\nname: {name}\ndescription: Exactly 64.\nlicense: MIT\n---\n\n## Body\n"
path = skill_dir / "SKILL.md"
path.write_text(content, encoding="utf-8")
result = validate_strict(path)
# May have other warnings but should not error on length
assert not any("64" in e or "characters" in e.lower() for e in result.errors)
class TestCheck8NameDirectoryMatch:
def test_error_on_name_dir_mismatch(self, tmp_path):
content = """\
---
name: other-skill
description: Wrong name.
license: MIT
---
## Body
"""
# Directory is my-skill but name is other-skill
path = _write_skill(tmp_path, content, dir_name="my-skill")
result = validate_strict(path)
assert result.passed is False
assert any("other-skill" in e or "my-skill" in e for e in result.errors)
def test_exact_match_passes(self, tmp_path):
content = """\
---
name: my-skill
description: Exact match.
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content, dir_name="my-skill")
result = validate_strict(path)
assert result.passed is True
def test_dot_namespace_prefix_passes(self, tmp_path):
"""org.my-skill with dir my-skill is valid (namespace prefix)."""
content = """\
---
name: org.my-skill
description: Namespaced.
license: MIT
---
## Body
"""
path = _write_skill(tmp_path, content, dir_name="my-skill")
result = validate_strict(path)
# A dot-namespace prefix must not be reported as a name/directory mismatch
mismatch_errors = [e for e in result.errors if "org.my-skill" in e]
assert mismatch_errors == []
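Check 8 accepts either an exact name/directory match or a dot-namespaced name whose final segment is the directory (`org.my-skill` in directory `my-skill`). A sketch of that comparison as the tests imply it, not the validator's actual implementation:

```python
def name_matches_directory(name: str, dir_name: str) -> bool:
    """Exact match, or a namespaced name whose last dot-segment is the directory."""
    if name == dir_name:
        return True
    # "org.my-skill" in directory "my-skill": strip the namespace prefix.
    return "." in name and name.rsplit(".", 1)[1] == dir_name
```

Under this rule `my-skill`, `hive.my-skill`, and `org.my-skill` all pass for directory `my-skill`, while `other-skill` fails — matching the three test cases in this class.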
class TestCheck9BodyNotEmpty:
def test_error_on_empty_body(self, tmp_path):
content = """\
---
name: my-skill
description: No body.
license: MIT
---
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("body" in e.lower() or "instructions" in e.lower() for e in result.errors)
class TestCheck11Scripts:
def test_error_on_non_executable_script(self, tmp_path):
path = _write_skill(tmp_path, _VALID_CONTENT)
scripts_dir = path.parent / "scripts"
scripts_dir.mkdir()
script = scripts_dir / "run.sh"
script.write_text("#!/bin/sh\necho hi")
# Ensure NOT executable
script.chmod(0o644)
result = validate_strict(path)
assert result.passed is False
assert any("executable" in e.lower() for e in result.errors)
def test_passes_with_executable_script(self, tmp_path):
path = _write_skill(tmp_path, _VALID_CONTENT)
scripts_dir = path.parent / "scripts"
scripts_dir.mkdir()
script = scripts_dir / "run.sh"
script.write_text("#!/bin/sh\necho hi")
script.chmod(0o755)
result = validate_strict(path)
assert result.passed is True
class TestCheck12AllowedTools:
def test_warning_on_malformed_allowed_tools(self, tmp_path):
content = """\
---
name: my-skill
description: Skill with bad tools.
license: MIT
allowed-tools: "not a list"
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert any("allowed-tools" in w.lower() for w in result.warnings)
def test_valid_allowed_tools_no_warning(self, tmp_path):
content = """\
---
name: my-skill
description: Valid tools list.
license: MIT
allowed-tools:
- web_search
- file_read
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert not any("allowed-tools" in w.lower() for w in result.warnings)
class TestCheck13Compatibility:
def test_error_on_non_list_compatibility(self, tmp_path):
content = """\
---
name: my-skill
description: Bad compat.
license: MIT
compatibility: "claude-code"
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("compatibility" in e.lower() for e in result.errors)
def test_valid_compatibility_passes(self, tmp_path):
content = """\
---
name: my-skill
description: Good compat.
license: MIT
compatibility:
- claude-code
- hive
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is True
class TestCheck14Metadata:
def test_error_on_non_dict_metadata(self, tmp_path):
content = """\
---
name: my-skill
description: Bad metadata.
license: MIT
metadata: "not a dict"
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is False
assert any("metadata" in e.lower() for e in result.errors)
def test_valid_metadata_passes(self, tmp_path):
content = """\
---
name: my-skill
description: Good metadata.
license: MIT
metadata:
tags:
- research
---
## Body
"""
path = _write_skill(tmp_path, content)
result = validate_strict(path)
assert result.passed is True
@@ -1,18 +1,16 @@
-"""Tests for the storage module - FileStorage and ConcurrentStorage backends.
+"""Tests for the storage module - ConcurrentStorage backend.
-DEPRECATED: FileStorage and ConcurrentStorage are deprecated.
+DEPRECATED: FileStorage has been removed.
 New sessions use unified storage at sessions/{session_id}/state.json.
-These tests are kept for backward compatibility verification only.
+These tests are kept for backward compatibility verification of ConcurrentStorage only.
 """
 import json
 import time
 from pathlib import Path
 import pytest
 from framework.schemas.run import Run, RunMetrics, RunStatus
-from framework.storage.backend import FileStorage
 from framework.storage.concurrent import CacheEntry, ConcurrentStorage
# === HELPER FUNCTIONS ===
@@ -40,277 +38,6 @@ def create_test_run(
)
# === FILESTORAGE TESTS ===
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
class TestFileStorageBasics:
"""Test basic FileStorage operations."""
def test_init_creates_directories(self, tmp_path: Path):
"""FileStorage should create the directory structure on init."""
FileStorage(tmp_path)
assert (tmp_path / "runs").exists()
assert (tmp_path / "summaries").exists()
assert (tmp_path / "indexes" / "by_goal").exists()
assert (tmp_path / "indexes" / "by_status").exists()
assert (tmp_path / "indexes" / "by_node").exists()
def test_init_with_string_path(self, tmp_path: Path):
"""FileStorage should accept string paths."""
storage = FileStorage(str(tmp_path))
assert storage.base_path == tmp_path
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
class TestFileStorageRunOperations:
"""Test FileStorage run CRUD operations."""
def test_save_and_load_run(self, tmp_path: Path):
"""Test saving and loading a run."""
storage = FileStorage(tmp_path)
run = create_test_run()
storage.save_run(run)
loaded = storage.load_run(run.id)
assert loaded is not None
assert loaded.id == run.id
assert loaded.goal_id == run.goal_id
assert loaded.status == run.status
def test_load_nonexistent_run_returns_none(self, tmp_path: Path):
"""Loading a nonexistent run should return None."""
storage = FileStorage(tmp_path)
result = storage.load_run("nonexistent_id")
assert result is None
def test_save_creates_json_file(self, tmp_path: Path):
"""Saving a run should create a JSON file."""
storage = FileStorage(tmp_path)
run = create_test_run(run_id="my_run")
storage.save_run(run)
run_file = tmp_path / "runs" / "my_run.json"
assert run_file.exists()
# Verify it's valid JSON
with open(run_file, encoding="utf-8") as f:
data = json.load(f)
assert data["id"] == "my_run"
def test_save_creates_summary(self, tmp_path: Path):
"""Saving a run should also create a summary file."""
storage = FileStorage(tmp_path)
run = create_test_run(run_id="my_run")
storage.save_run(run)
summary_file = tmp_path / "summaries" / "my_run.json"
assert summary_file.exists()
def test_load_summary(self, tmp_path: Path):
"""Test loading a run summary."""
storage = FileStorage(tmp_path)
run = create_test_run()
storage.save_run(run)
summary = storage.load_summary(run.id)
assert summary is not None
assert summary.run_id == run.id
assert summary.goal_id == run.goal_id
assert summary.status == run.status
def test_load_summary_fallback_to_run(self, tmp_path: Path):
"""If summary file is missing, load_summary should compute from run."""
storage = FileStorage(tmp_path)
run = create_test_run()
storage.save_run(run)
# Delete the summary file
summary_file = tmp_path / "summaries" / f"{run.id}.json"
summary_file.unlink()
# Should still work by computing from run
summary = storage.load_summary(run.id)
assert summary is not None
assert summary.run_id == run.id
def test_delete_run(self, tmp_path: Path):
"""Test deleting a run."""
storage = FileStorage(tmp_path)
run = create_test_run()
storage.save_run(run)
assert storage.load_run(run.id) is not None
result = storage.delete_run(run.id)
assert result is True
assert storage.load_run(run.id) is None
def test_delete_nonexistent_run_returns_false(self, tmp_path: Path):
"""Deleting a nonexistent run should return False."""
storage = FileStorage(tmp_path)
result = storage.delete_run("nonexistent")
assert result is False
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
class TestFileStorageIndexing:
"""Test FileStorage index operations."""
def test_index_by_goal(self, tmp_path: Path):
"""Runs should be indexed by goal_id."""
storage = FileStorage(tmp_path)
run1 = create_test_run(run_id="run_1", goal_id="goal_a")
run2 = create_test_run(run_id="run_2", goal_id="goal_a")
run3 = create_test_run(run_id="run_3", goal_id="goal_b")
storage.save_run(run1)
storage.save_run(run2)
storage.save_run(run3)
goal_a_runs = storage.get_runs_by_goal("goal_a")
goal_b_runs = storage.get_runs_by_goal("goal_b")
assert len(goal_a_runs) == 2
assert "run_1" in goal_a_runs
assert "run_2" in goal_a_runs
assert len(goal_b_runs) == 1
assert "run_3" in goal_b_runs
def test_index_by_status(self, tmp_path: Path):
"""Runs should be indexed by status."""
storage = FileStorage(tmp_path)
run1 = create_test_run(run_id="run_1", status=RunStatus.COMPLETED)
run2 = create_test_run(run_id="run_2", status=RunStatus.FAILED)
run3 = create_test_run(run_id="run_3", status=RunStatus.COMPLETED)
storage.save_run(run1)
storage.save_run(run2)
storage.save_run(run3)
completed = storage.get_runs_by_status(RunStatus.COMPLETED)
failed = storage.get_runs_by_status(RunStatus.FAILED)
assert len(completed) == 2
assert len(failed) == 1
def test_index_by_status_string(self, tmp_path: Path):
"""get_runs_by_status should accept string status."""
storage = FileStorage(tmp_path)
run = create_test_run(status=RunStatus.RUNNING)
storage.save_run(run)
runs = storage.get_runs_by_status("running")
assert len(runs) == 1
def test_index_by_node(self, tmp_path: Path):
"""Runs should be indexed by executed nodes."""
storage = FileStorage(tmp_path)
run1 = create_test_run(run_id="run_1", nodes_executed=["node_a", "node_b"])
run2 = create_test_run(run_id="run_2", nodes_executed=["node_a", "node_c"])
storage.save_run(run1)
storage.save_run(run2)
node_a_runs = storage.get_runs_by_node("node_a")
node_b_runs = storage.get_runs_by_node("node_b")
node_c_runs = storage.get_runs_by_node("node_c")
assert len(node_a_runs) == 2
assert len(node_b_runs) == 1
assert len(node_c_runs) == 1
def test_delete_removes_from_indexes(self, tmp_path: Path):
"""Deleting a run should remove it from all indexes."""
storage = FileStorage(tmp_path)
run = create_test_run(
run_id="run_1",
goal_id="goal_a",
status=RunStatus.COMPLETED,
nodes_executed=["node_1"],
)
storage.save_run(run)
# Verify indexed
assert "run_1" in storage.get_runs_by_goal("goal_a")
assert "run_1" in storage.get_runs_by_status(RunStatus.COMPLETED)
assert "run_1" in storage.get_runs_by_node("node_1")
# Delete
storage.delete_run("run_1")
# Verify removed from indexes
assert "run_1" not in storage.get_runs_by_goal("goal_a")
assert "run_1" not in storage.get_runs_by_status(RunStatus.COMPLETED)
assert "run_1" not in storage.get_runs_by_node("node_1")
def test_empty_index_returns_empty_list(self, tmp_path: Path):
"""Querying an empty index should return empty list."""
storage = FileStorage(tmp_path)
assert storage.get_runs_by_goal("nonexistent") == []
assert storage.get_runs_by_status("nonexistent") == []
assert storage.get_runs_by_node("nonexistent") == []
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
class TestFileStorageListOperations:
"""Test FileStorage list operations."""
def test_list_all_runs(self, tmp_path: Path):
"""Test listing all run IDs."""
storage = FileStorage(tmp_path)
storage.save_run(create_test_run(run_id="run_1"))
storage.save_run(create_test_run(run_id="run_2"))
storage.save_run(create_test_run(run_id="run_3"))
all_runs = storage.list_all_runs()
assert len(all_runs) == 3
assert set(all_runs) == {"run_1", "run_2", "run_3"}
def test_list_all_goals(self, tmp_path: Path):
"""Test listing all goal IDs that have runs."""
storage = FileStorage(tmp_path)
storage.save_run(create_test_run(run_id="run_1", goal_id="goal_a"))
storage.save_run(create_test_run(run_id="run_2", goal_id="goal_b"))
storage.save_run(create_test_run(run_id="run_3", goal_id="goal_a"))
all_goals = storage.list_all_goals()
assert len(all_goals) == 2
assert set(all_goals) == {"goal_a", "goal_b"}
def test_get_stats(self, tmp_path: Path):
"""Test getting storage statistics."""
storage = FileStorage(tmp_path)
storage.save_run(create_test_run(run_id="run_1", goal_id="goal_a"))
storage.save_run(create_test_run(run_id="run_2", goal_id="goal_b"))
stats = storage.get_stats()
assert stats["total_runs"] == 2
assert stats["total_goals"] == 2
assert stats["storage_path"] == str(tmp_path)
# === CACHE ENTRY TESTS ===
@@ -332,7 +59,6 @@ class TestCacheEntry:
# === CONCURRENTSTORAGE TESTS ===
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageBasics:
"""Test basic ConcurrentStorage operations."""
@@ -377,168 +103,6 @@ class TestConcurrentStorageBasics:
assert storage._running is False
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageRunOperations:
"""Test ConcurrentStorage run operations."""
@pytest.mark.asyncio
async def test_save_and_load_run(self, tmp_path: Path):
"""Test async save and load of a run."""
storage = ConcurrentStorage(tmp_path)
await storage.start()
try:
run = create_test_run()
await storage.save_run(run, immediate=True)
loaded = await storage.load_run(run.id)
assert loaded is not None
assert loaded.id == run.id
assert loaded.goal_id == run.goal_id
finally:
await storage.stop()
@pytest.mark.asyncio
async def test_load_run_uses_cache(self, tmp_path: Path):
"""Second load should use cached value."""
storage = ConcurrentStorage(tmp_path)
await storage.start()
try:
run = create_test_run()
await storage.save_run(run, immediate=True)
# First load
loaded1 = await storage.load_run(run.id)
# Second load (should use cache)
loaded2 = await storage.load_run(run.id, use_cache=True)
assert loaded1 is not None
assert loaded2 is not None
# Cache should return same object
assert loaded1 is loaded2
finally:
await storage.stop()
@pytest.mark.asyncio
async def test_load_run_bypass_cache(self, tmp_path: Path):
"""Load with use_cache=False should bypass cache."""
storage = ConcurrentStorage(tmp_path)
await storage.start()
try:
run = create_test_run()
await storage.save_run(run, immediate=True)
loaded1 = await storage.load_run(run.id)
loaded2 = await storage.load_run(run.id, use_cache=False)
assert loaded1 is not None
assert loaded2 is not None
# Fresh load should be different object
assert loaded1 is not loaded2
finally:
await storage.stop()
    @pytest.mark.asyncio
    async def test_delete_run(self, tmp_path: Path):
        """Test async delete of a run."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            run = create_test_run()
            await storage.save_run(run, immediate=True)
            result = await storage.delete_run(run.id)
            assert result is True
            loaded = await storage.load_run(run.id)
            assert loaded is None
        finally:
            await storage.stop()

    @pytest.mark.asyncio
    async def test_delete_clears_cache(self, tmp_path: Path):
        """Deleting a run should clear it from cache."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            run = create_test_run()
            await storage.save_run(run, immediate=True)
            # Load to populate cache
            await storage.load_run(run.id)
            assert f"run:{run.id}" in storage._cache
            # Delete
            await storage.delete_run(run.id)
            # Cache should be cleared
            assert f"run:{run.id}" not in storage._cache
        finally:
            await storage.stop()


@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageQueryOperations:
    """Test ConcurrentStorage query operations."""

    @pytest.mark.asyncio
    async def test_get_runs_by_goal(self, tmp_path: Path):
        """Test async query by goal."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            run1 = create_test_run(run_id="run_1", goal_id="goal_a")
            run2 = create_test_run(run_id="run_2", goal_id="goal_a")
            await storage.save_run(run1, immediate=True)
            await storage.save_run(run2, immediate=True)
            runs = await storage.get_runs_by_goal("goal_a")
            assert len(runs) == 2
        finally:
            await storage.stop()

    @pytest.mark.asyncio
    async def test_get_runs_by_status(self, tmp_path: Path):
        """Test async query by status."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            run = create_test_run(status=RunStatus.FAILED)
            await storage.save_run(run, immediate=True)
            runs = await storage.get_runs_by_status(RunStatus.FAILED)
            assert len(runs) == 1
        finally:
            await storage.stop()

    @pytest.mark.asyncio
    async def test_list_all_runs(self, tmp_path: Path):
        """Test async list all runs."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            await storage.save_run(create_test_run(run_id="run_1"), immediate=True)
            await storage.save_run(create_test_run(run_id="run_2"), immediate=True)
            runs = await storage.list_all_runs()
            assert len(runs) == 2
        finally:
            await storage.stop()


@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageCacheManagement:
    """Test ConcurrentStorage cache management."""
@@ -576,60 +140,3 @@ class TestConcurrentStorageCacheManagement:
        assert stats["total_entries"] == 2
        assert stats["expired_entries"] == 1
        assert stats["valid_entries"] == 1
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageSyncAPI:
    """Test ConcurrentStorage synchronous API for backward compatibility."""

    def test_save_run_sync(self, tmp_path: Path):
        """Test synchronous save."""
        storage = ConcurrentStorage(tmp_path)
        run = create_test_run()
        storage.save_run_sync(run)
        # Verify saved
        loaded = storage.load_run_sync(run.id)
        assert loaded is not None
        assert loaded.id == run.id

    def test_load_run_sync(self, tmp_path: Path):
        """Test synchronous load."""
        storage = ConcurrentStorage(tmp_path)
        run = create_test_run()
        storage.save_run_sync(run)
        loaded = storage.load_run_sync(run.id)
        assert loaded is not None

    def test_load_run_sync_nonexistent(self, tmp_path: Path):
        """Synchronous load of nonexistent run returns None."""
        storage = ConcurrentStorage(tmp_path)
        loaded = storage.load_run_sync("nonexistent")
        assert loaded is None


@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
class TestConcurrentStorageStats:
    """Test ConcurrentStorage statistics."""

    @pytest.mark.asyncio
    async def test_get_stats(self, tmp_path: Path):
        """Test getting async storage stats."""
        storage = ConcurrentStorage(tmp_path)
        await storage.start()
        try:
            await storage.save_run(create_test_run(), immediate=True)
            stats = await storage.get_stats()
            assert stats["total_runs"] == 1
            assert "cache" in stats
            assert "pending_writes" in stats
            assert stats["running"] is True
        finally:
            await storage.stop()
+1 -1
@@ -214,7 +214,7 @@ def test_load_registry_servers_retries_when_registration_returns_zero(monkeypatc
registry = ToolRegistry()
attempts = {"count": 0}
def fake_register(server_config, use_connection_manager=True):
def fake_register(server_config, use_connection_manager=True, **kwargs):
attempts["count"] += 1
return 0 if attempts["count"] == 1 else 2
+53 -40
@@ -198,33 +198,44 @@ Use the coder-tools MCP tools from your IDE agent chat (e.g., initialize_and_bui
If you prefer to build agents manually:
```python
# exports/my_agent/agent.json
```jsonc
// exports/my_agent/agent.json
{
"goal": {
"agent": {
"id": "my_agent",
"name": "Support Ticket Handler",
"version": "1.0.0",
"description": "Process customer support tickets"
},
"graph": {
"id": "my_agent-graph",
"goal_id": "support_ticket",
"entry_node": "analyze",
"terminal_nodes": ["analyze"],
"nodes": [
{
"id": "analyze",
"name": "Analyze Ticket",
"description": "Categorize and prioritize the support ticket",
"node_type": "event_loop",
"system_prompt": "Analyze this support ticket...",
"input_keys": ["ticket_content"],
"output_keys": ["category", "priority"]
}
],
"edges": []
},
"goal": {
"id": "support_ticket",
"name": "Support Ticket Handler",
"description": "Process customer support tickets",
"success_criteria": "Ticket is categorized, prioritized, and routed correctly"
},
"nodes": [
{
"node_id": "analyze",
"name": "Analyze Ticket",
"node_type": "event_loop",
"system_prompt": "Analyze this support ticket...",
"input_keys": ["ticket_content"],
"output_keys": ["category", "priority"]
}
],
"edges": [
{
"edge_id": "start_to_analyze",
"source": "START",
"target": "analyze",
"condition": "on_success"
}
]
"success_criteria": [
{
"id": "sc-categorized",
"description": "Ticket is categorized and prioritized correctly"
}
]
}
}
```
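Since this export is hand-written, a quick structural check helps catch drift between `graph.entry_node`, `graph.terminal_nodes`, and the node list. A minimal sketch of such a check — a hypothetical helper, with key names taken from the example above rather than from a published schema:

```python
import json

# Top-level keys present in the example export above (assumed, not a formal schema)
REQUIRED_TOP_KEYS = {"agent", "graph", "goal"}


def validate_agent_export(raw: str) -> list[str]:
    """Return a list of problems; an empty list means the export looks well-formed."""
    doc = json.loads(raw)
    problems = [f"missing top-level key: {key}" for key in sorted(REQUIRED_TOP_KEYS - doc.keys())]
    graph = doc.get("graph", {})
    node_ids = {node.get("id") for node in graph.get("nodes", [])}
    if graph.get("entry_node") not in node_ids:
        problems.append("entry_node does not match any node id")
    for terminal in graph.get("terminal_nodes", []):
        if terminal not in node_ids:
            problems.append(f"terminal node {terminal!r} not found in nodes")
    return problems
```

Running this against `exports/my_agent/agent.json` before handing it to the CLI makes broken node references fail fast instead of at runtime.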
@@ -532,16 +543,17 @@ def my_custom_tool(param1: str, param2: int) -> Dict[str, Any]:
# Implementation
return {"result": "success", "data": ...}
# Register tool in agent.json
# Register tool in agent.json (inside "graph" → "nodes")
{
"nodes": [
{
"node_id": "use_tool",
"node_type": "event_loop",
"tools": ["my_custom_tool"],
...
}
]
"graph": {
"nodes": [
{
"id": "use_tool",
"node_type": "event_loop",
"tools": ["my_custom_tool"]
}
]
}
}
```
@@ -560,15 +572,16 @@ def my_custom_tool(param1: str, param2: int) -> Dict[str, Any]:
}
}
# 2. Reference tools in agent.json
# 2. Reference tools in agent.json (inside "graph" → "nodes")
{
"nodes": [
{
"node_id": "search",
"tools": ["web_search", "web_scrape"],
...
}
]
"graph": {
"nodes": [
{
"id": "search",
"tools": ["web_search", "web_scrape"]
}
]
}
}
```
+4 -3
@@ -31,8 +31,8 @@
"nullable_output_keys": [],
"input_schema": {},
"output_schema": {},
"system_prompt": "You are a career analyst helping a job seeker find their best opportunities.\n\n**STEP 1 \u2014 Greet and collect resume (text only, NO tool calls):**\n\nAsk the user to paste their resume. Be friendly and concise:\n\"Please paste your resume below. I'll analyze your experience and identify the roles where you have the strongest chance of success.\"\n\n**STEP 2 \u2014 After the user provides their resume:**\n\nAnalyze the resume thoroughly:\n1. Identify key skills (technical and soft skills)\n2. Summarize years and types of experience\n3. Identify 3-5 SPECIFIC, GRANULAR role types where they're competitive\n\n**IMPORTANT \u2014 Role Specificity:**\nRespect the job seeker by providing granular options, not generic buckets.\n- BAD: \"Software Engineer\" (too broad)\n- GOOD: \"Backend Engineer (Python/Django)\", \"Platform Engineer\", \"API Developer\", \"Data Pipeline Engineer\"\n\nEach role should be distinct and searchable. The more specific, the better the job matches will be\n\nPresent your analysis to the user and ask if they agree with the role types identified. DO NOT ask follow-up questions. DO NOT ask which roles to focus on.\n\n**STEP 3 \u2014 After user confirms roles, call set_output:**\n\nUse set_output to store:\n- set_output(\"resume_text\", \"<the full resume text>\")\n- set_output(\"role_analysis\", \"<JSON with: skills, experience_summary, target_roles (3-5 specific role titles)>\")\n\nIMPORTANT: When the user says \"yes\", \"sure\", \"go ahead\", \"find jobs\" or similar, call set_output IMMEDIATELY. NEVER ask the user to pick between roles.",
"tools": [],
"system_prompt": "You are a career analyst helping a job seeker find their best opportunities.\n\n**STEP 1 \u2014 Greet and collect resume:**\n\nAsk the user to provide their resume. They can either paste the text directly or provide a path to a PDF file. Be friendly and concise:\n\"Please paste your resume below, or provide the file path to your PDF resume (e.g., /path/to/resume.pdf). I'll analyze your experience and identify the roles where you have the strongest chance of success.\"\n\nIf the user provides a file path to a PDF, call pdf_read(file_path=\"<path>\") to extract the text before proceeding.\n\n**STEP 2 \u2014 After the user provides their resume:**\n\nAnalyze the resume thoroughly:\n1. Identify key skills (technical and soft skills)\n2. Summarize years and types of experience\n3. Identify 3-5 SPECIFIC, GRANULAR role types where they're competitive\n\n**IMPORTANT \u2014 Role Specificity:**\nRespect the job seeker by providing granular options, not generic buckets.\n- BAD: \"Software Engineer\" (too broad)\n- GOOD: \"Backend Engineer (Python/Django)\", \"Platform Engineer\", \"API Developer\", \"Data Pipeline Engineer\"\n\nEach role should be distinct and searchable. The more specific, the better the job matches will be\n\nPresent your analysis to the user and ask if they agree with the role types identified. DO NOT ask follow-up questions. DO NOT ask which roles to focus on.\n\n**STEP 3 \u2014 After user confirms roles, call set_output:**\n\nUse set_output to store:\n- set_output(\"resume_text\", \"<the full resume text>\")\n- set_output(\"role_analysis\", \"<JSON with: skills, experience_summary, target_roles (3-5 specific role titles)>\")\n\nIMPORTANT: When the user says \"yes\", \"sure\", \"go ahead\", \"find jobs\" or similar, call set_output IMMEDIATELY. NEVER ask the user to pick between roles.",
"tools": ["pdf_read"],
"model": null,
"function": null,
"routes": {},
@@ -261,7 +261,8 @@
"append_data",
"serve_file_to_user",
"web_scrape",
"gmail_create_draft"
"gmail_create_draft",
"pdf_read"
],
"metadata": {
"created_at": "2026-02-13T18:41:10.324531",
@@ -9,9 +9,9 @@ intake_node = NodeSpec(
name="Intake",
description="Analyze resume and identify 3-5 strongest role types",
node_type="event_loop",
client_facing=False,
client_facing=True,
max_node_visits=1,
input_keys=["resume_text"],
input_keys=[],
output_keys=["resume_text", "role_analysis"],
success_criteria=(
"The user's resume has been analyzed and 3-5 target roles identified "
@@ -20,6 +20,12 @@ intake_node = NodeSpec(
system_prompt="""\
You are a career analyst. Your task is to analyze the user's resume and identify the best role fits.
**ACCEPTING THE RESUME:**
The user can provide their resume in two ways:
1. **Paste text**: The user pastes their resume content directly.
2. **PDF file path**: The user provides a path to a PDF file (e.g., "/path/to/resume.pdf"). \
If a file path is provided, call pdf_read(file_path="<path>") to extract the text before analyzing.
**PROCESS:**
1. Identify key skills (technical and soft skills).
2. Summarize years and types of experience.
@@ -32,7 +38,7 @@ You MUST call set_output to store:
Do NOT wait for user confirmation. Simply perform the analysis and set the outputs.
""",
tools=[],
tools=["pdf_read"],
)
# Node 2: Job Search (simple)
@@ -0,0 +1 @@
{ "include": ["hive-tools"] }
+165 -13
@@ -1022,6 +1022,10 @@ $hiveKey = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
if (-not $hiveKey) { $hiveKey = $env:HIVE_API_KEY }
if ($hiveKey) { $HiveCredDetected = $true }
$AntigravityCredDetected = $false
$antigravityAuthPath = Join-Path $env:USERPROFILE ".hive\antigravity-accounts.json"
if (Test-Path $antigravityAuthPath) { $AntigravityCredDetected = $true }
# Detect API key providers
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY", "OPENROUTER_API_KEY")
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier", "OpenRouter - Bring any OpenRouter model")
@@ -1035,6 +1039,12 @@ $ProviderMenuUrls = @(
"https://openrouter.ai/keys"
)
$OllamaDetected = $false
try {
$null = & ollama list 2>$null
if ($LASTEXITCODE -eq 0) { $OllamaDetected = $true }
} catch { }
# ── Read previous configuration (if any) ──────────────────────
$PrevProvider = ""
$PrevModel = ""
@@ -1051,6 +1061,7 @@ if (Test-Path $HiveConfigFile) {
if ($prevLlm.use_claude_code_subscription) { $PrevSubMode = "claude_code" }
elseif ($prevLlm.use_codex_subscription) { $PrevSubMode = "codex" }
elseif ($prevLlm.use_kimi_code_subscription) { $PrevSubMode = "kimi_code" }
elseif ($prevLlm.use_antigravity_subscription) { $PrevSubMode = "antigravity" }
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.z.ai*") { $PrevSubMode = "zai_code" }
elseif ($prevLlm.provider -eq "minimax" -or ($prevLlm.api_base -and $prevLlm.api_base -like "*api.minimax.io*")) { $PrevSubMode = "minimax_code" }
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.kimi.com*") { $PrevSubMode = "kimi_code" }
@@ -1070,8 +1081,11 @@ if ($PrevSubMode -or $PrevProvider) {
"minimax_code" { if ($MinimaxCredDetected) { $prevCredValid = $true } }
"kimi_code" { if ($KimiCredDetected) { $prevCredValid = $true } }
"hive_llm" { if ($HiveCredDetected) { $prevCredValid = $true } }
"antigravity" { if ($AntigravityCredDetected) { $prevCredValid = $true } }
default {
if ($PrevEnvVar) {
if ($PrevProvider -eq "ollama") {
$prevCredValid = $true
} elseif ($PrevEnvVar) {
$envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "Process")
if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "User") }
if ($envVal) { $prevCredValid = $true }
@@ -1086,17 +1100,20 @@ if ($PrevSubMode -or $PrevProvider) {
"minimax_code" { $DefaultChoice = "4" }
"kimi_code" { $DefaultChoice = "5" }
"hive_llm" { $DefaultChoice = "6" }
"antigravity" { $DefaultChoice = "7" }
}
if (-not $DefaultChoice) {
switch ($PrevProvider) {
"anthropic" { $DefaultChoice = "7" }
"openai" { $DefaultChoice = "8" }
"gemini" { $DefaultChoice = "9" }
"groq" { $DefaultChoice = "10" }
"cerebras" { $DefaultChoice = "11" }
"openrouter" { $DefaultChoice = "12" }
"anthropic" { $DefaultChoice = "8" }
"openai" { $DefaultChoice = "9" }
"gemini" { $DefaultChoice = "10" }
"groq" { $DefaultChoice = "11" }
"cerebras" { $DefaultChoice = "12" }
"openrouter" { $DefaultChoice = "13" }
"ollama" { $DefaultChoice = "14" }
"minimax" { $DefaultChoice = "4" }
"kimi" { $DefaultChoice = "5" }
"hive" { $DefaultChoice = "6" }
}
}
}
@@ -1149,12 +1166,19 @@ Write-Host ") Hive LLM " -NoNewline
Write-Color -Text "(use your Hive API key)" -Color DarkGray -NoNewline
if ($HiveCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
# 7) Antigravity
Write-Host " " -NoNewline
Write-Color -Text "7" -Color Cyan -NoNewline
Write-Host ") Antigravity Subscription " -NoNewline
Write-Color -Text "(use your Google/Gemini plan)" -Color DarkGray -NoNewline
if ($AntigravityCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
Write-Host ""
Write-Color -Text " API key providers:" -Color Cyan
# 7-12) API key providers
# 8-13) API key providers
for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
$num = $idx + 7
$num = $idx + 8
$envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "Process")
if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "User") }
Write-Host " " -NoNewline
@@ -1163,7 +1187,17 @@ for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
if ($envVal) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
}
$SkipChoice = 7 + $ProviderMenuEnvVars.Count
# 14) Local (Ollama) - no API key needed
Write-Host " " -NoNewline
Write-Color -Text "14" -Color Cyan -NoNewline
if ($OllamaDetected) {
Write-Host ") Local (Ollama) - No API key needed " -NoNewline
Write-Color -Text "(ollama detected)" -Color Green
} else {
Write-Host ") Local (Ollama) - No API key needed"
}
$SkipChoice = 8 + $ProviderMenuEnvVars.Count + 1
Write-Host " " -NoNewline
Write-Color -Text "$SkipChoice" -Color Cyan -NoNewline
Write-Host ") Skip for now"
@@ -1301,9 +1335,48 @@ switch ($num) {
}
Write-Color -Text " Model: $SelectedModel | API: $HiveLlmEndpoint" -Color DarkGray
}
{ $_ -ge 7 -and $_ -le 12 } {
7 {
# Antigravity Subscription
if (-not $AntigravityCredDetected) {
Write-Host ""
Write-Color -Text " Setting up Antigravity authentication..." -Color Cyan
Write-Host ""
Write-Warn "A browser window will open for Google OAuth."
Write-Host " Sign in with your Google account that has Antigravity access."
Write-Host ""
try {
$null = & $UvCmd run python (Join-Path $ScriptDir "core\antigravity_auth.py") auth account add 2>&1
if ($LASTEXITCODE -eq 0 -and (Test-Path $antigravityAuthPath)) {
$AntigravityCredDetected = $true
}
} catch {
$AntigravityCredDetected = $false
}
if (-not $AntigravityCredDetected) {
Write-Host ""
Write-Fail "Authentication failed or was cancelled."
Write-Host ""
$SelectedProviderId = ""
}
}
if ($AntigravityCredDetected) {
$SubscriptionMode = "antigravity"
$SelectedProviderId = "openai"
$SelectedModel = "gemini-3-flash"
$SelectedMaxTokens = 32768
$SelectedMaxContextTokens = 1000000
Write-Host ""
Write-Warn "Using Antigravity may lead to suspension of your account. Use at your own risk."
Write-Host ""
Write-Ok "Using Antigravity subscription"
Write-Color -Text " Model: gemini-3-flash | Direct OAuth (no proxy required)" -Color DarkGray
}
}
{ $_ -ge 8 -and $_ -le 13 } {
# API key providers
$provIdx = $num - 7
$provIdx = $num - 8
$SelectedEnvVar = $ProviderMenuEnvVars[$provIdx]
$SelectedProviderId = $ProviderMenuIds[$provIdx]
$providerName = $ProviderMenuNames[$provIdx] -replace ' - .*', '' # strip description
@@ -1383,6 +1456,75 @@ switch ($num) {
}
}
}
14 {
# Local (Ollama)
if (-not $OllamaDetected) {
Write-Host ""
Write-Warn "The Local (Ollama) option requires a running Ollama server, but 'ollama list' failed."
Write-Host " Please install Ollama (https://ollama.com) and start the server,"
Write-Host " then run this quickstart again."
Write-Host ""
exit 1
}
$SelectedProviderId = "ollama"
Write-Host ""
Write-Ok "Using Local (Ollama)"
Write-Host ""
# Fetch available models
$ollamaModels = @()
try {
$listOutput = & ollama list 2>$null
if ($listOutput.Count -gt 1) {
for ($i = 1; $i -lt $listOutput.Count; $i++) {
$line = $listOutput[$i].Trim()
if ($line) {
$mName = ($line -split '\s+')[0]
if ($mName) { $ollamaModels += $mName }
}
}
}
} catch { }
if ($ollamaModels.Count -eq 0) {
Write-Warn "No Ollama models found."
Write-Host " Please open another terminal, run 'ollama run <model>' (e.g. 'ollama run llama3'),"
Write-Host " and then run this quickstart again."
Write-Host ""
exit 1
}
# Show model picker
Write-Host " Select an Ollama model:"
Write-Host ""
$defaultIdx = "1"
for ($i = 0; $i -lt $ollamaModels.Count; $i++) {
Write-Color -Text " $($i + 1)" -Color Cyan -NoNewline
Write-Host ") $($ollamaModels[$i])"
if ($PrevProvider -eq "ollama" -and $PrevModel -eq $ollamaModels[$i]) {
$defaultIdx = [string]($i + 1)
}
}
Write-Host ""
while ($true) {
$raw = Read-Host "Enter choice (1-$($ollamaModels.Count)) [$defaultIdx]"
if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $defaultIdx }
if ($raw -match '^\d+$') {
$num = [int]$raw
if ($num -ge 1 -and $num -le $ollamaModels.Count) {
$SelectedModel = $ollamaModels[$num - 1]
Write-Host ""
Write-Ok "Model: $SelectedModel"
$SelectedMaxTokens = 8192
$SelectedMaxContextTokens = 16384
$SelectedApiBase = "http://localhost:11434"
break
}
}
Write-Color -Text "Invalid choice. Please enter 1-$($ollamaModels.Count)" -Color Red
}
}
{ $_ -eq $SkipChoice } {
Write-Host ""
Write-Warn "Skipped. An LLM API key is required to test and use worker agents."
@@ -1686,6 +1828,8 @@ if ($SelectedProviderId) {
$config.llm["use_claude_code_subscription"] = $true
} elseif ($SubscriptionMode -eq "codex") {
$config.llm["use_codex_subscription"] = $true
} elseif ($SubscriptionMode -eq "antigravity") {
$config.llm["use_antigravity_subscription"] = $true
} elseif ($SubscriptionMode -eq "zai_code") {
$config.llm["api_base"] = "https://api.z.ai/api/coding/paas/v4"
$config.llm["api_key_env_var"] = $SelectedEnvVar
@@ -1701,8 +1845,13 @@ if ($SelectedProviderId) {
} elseif ($SelectedProviderId -eq "openrouter") {
$config.llm["api_base"] = "https://openrouter.ai/api/v1"
$config.llm["api_key_env_var"] = $SelectedEnvVar
} else {
} elseif ($SelectedProviderId -eq "ollama") {
$config.llm["api_base"] = "http://localhost:11434"
$config.llm.Remove("api_key_env_var")
} elseif ($SelectedEnvVar) {
$config.llm["api_key_env_var"] = $SelectedEnvVar
} else {
$config.llm.Remove("api_key_env_var")
}
$config | ConvertTo-Json -Depth 4 | Set-Content -Path $HiveConfigFile -Encoding UTF8
@@ -2003,6 +2152,9 @@ if ($SelectedProviderId) {
Write-Color -Text " API: api.minimax.io/v1 (OpenAI-compatible)" -Color DarkGray
} elseif ($SubscriptionMode -eq "codex") {
Write-Ok "OpenAI Codex Subscription -> $SelectedModel"
} elseif ($SubscriptionMode -eq "antigravity") {
Write-Ok "Antigravity Subscription -> $SelectedModel"
Write-Color -Text " Direct OAuth (no proxy required)" -Color DarkGray
} elseif ($SelectedProviderId -eq "openrouter") {
Write-Ok "OpenRouter API Key -> $SelectedModel"
Write-Color -Text " API: openrouter.ai/api/v1 (OpenAI-compatible)" -Color DarkGray
+97 -8
@@ -673,7 +673,18 @@ detect_shell_rc() {
fi
;;
bash)
if [ -f "$HOME/.bashrc" ]; then
# Git Bash on Windows commonly starts as a login shell, so prefer
# .bash_profile there when it already exists. On Unix-like shells,
# keep the traditional .bashrc-first behavior.
if [ -n "$MSYSTEM" ] || [ -n "$MINGW_PREFIX" ]; then
if [ -f "$HOME/.bash_profile" ]; then
echo "$HOME/.bash_profile"
elif [ -f "$HOME/.bashrc" ]; then
echo "$HOME/.bashrc"
else
echo "$HOME/.profile"
fi
elif [ -f "$HOME/.bashrc" ]; then
echo "$HOME/.bashrc"
elif [ -f "$HOME/.bash_profile" ]; then
echo "$HOME/.bash_profile"
@@ -912,8 +923,9 @@ config["llm"] = {
"model": model,
"max_tokens": int(max_tokens),
"max_context_tokens": int(max_context_tokens),
"api_key_env_var": env_var,
}
if env_var:
config["llm"]["api_key_env_var"] = env_var
config["created_at"] = created_at
if use_claude_code_sub == "true":
@@ -1024,6 +1036,11 @@ elif [ -f "$HOME/.hive/antigravity-accounts.json" ]; then
ANTIGRAVITY_CRED_DETECTED=true
fi
OLLAMA_DETECTED=false
if ollama list >/dev/null 2>&1; then
OLLAMA_DETECTED=true
fi
# Detect API key providers
if [ "$USE_ASSOC_ARRAYS" = true ]; then
for env_var in "${!PROVIDER_NAMES[@]}"; do
@@ -1056,9 +1073,12 @@ try:
with open(cfg_path, encoding="utf-8-sig") as f:
c = json.load(f)
llm = c.get("llm", {})
print(f"PREV_PROVIDER={llm.get(\"provider\", \"\")}")
print(f"PREV_MODEL={llm.get(\"model\", \"\")}")
print(f"PREV_ENV_VAR={llm.get(\"api_key_env_var\", \"\")}")
prov = llm.get("provider", "")
mod = llm.get("model", "")
env = llm.get("api_key_env_var", "")
print(f"PREV_PROVIDER='{prov}'")
print(f"PREV_MODEL='{mod}'")
print(f"PREV_ENV_VAR='{env}'")
sub = ""
if llm.get("use_claude_code_subscription"):
sub = "claude_code"
@@ -1093,8 +1113,12 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
hive_llm) [ "$HIVE_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
antigravity) [ "$ANTIGRAVITY_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
*)
# API key provider — check if the env var is set
if [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
# API key provider — check if the env var is set; ollama uses local runtime detection
if [ "$PREV_PROVIDER" = "ollama" ]; then
if [ "$OLLAMA_DETECTED" = true ]; then
PREV_CRED_VALID=true
fi
elif [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
PREV_CRED_VALID=true
fi
;;
@@ -1118,6 +1142,7 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
groq) DEFAULT_CHOICE=11 ;;
cerebras) DEFAULT_CHOICE=12 ;;
openrouter) DEFAULT_CHOICE=13 ;;
ollama) DEFAULT_CHOICE=14 ;;
minimax) DEFAULT_CHOICE=4 ;;
kimi) DEFAULT_CHOICE=5 ;;
hive) DEFAULT_CHOICE=6 ;;
@@ -1196,7 +1221,14 @@ for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
fi
done
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]}))
# 14) Local (Ollama) — no API key needed
if [ "$OLLAMA_DETECTED" = true ]; then
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed ${GREEN}(ollama detected)${NC}"
else
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed"
fi
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]} + 1))
echo -e " ${CYAN}$SKIP_CHOICE)${NC} Skip for now"
echo ""
@@ -1414,6 +1446,56 @@ case $choice in
PROVIDER_NAME="OpenRouter"
SIGNUP_URL="https://openrouter.ai/keys"
;;
14)
# Local (Ollama) — no API key; pick model from ollama list
if [ "$OLLAMA_DETECTED" != true ]; then
echo ""
echo -e "${YELLOW}The Local (Ollama) option requires a running Ollama server, but 'ollama list' failed.${NC}"
echo -e " Please install Ollama (https://ollama.com) and start the server,"
echo -e " then run this quickstart again."
echo ""
exit 1
fi
SELECTED_PROVIDER_ID="ollama"
SELECTED_ENV_VAR=""
SELECTED_MAX_TOKENS=8192
SELECTED_MAX_CONTEXT_TOKENS=16384
OLLAMA_MODELS=()
while IFS= read -r line; do
[ -n "$line" ] && OLLAMA_MODELS+=("$line")
done < <(ollama list 2>/dev/null | tail -n +2 | awk '{print $1}')
if [ ${#OLLAMA_MODELS[@]} -gt 0 ]; then
echo ""
echo -e "${BOLD}Select an Ollama model:${NC}"
echo ""
for idx in "${!OLLAMA_MODELS[@]}"; do
num=$((idx + 1))
echo -e " ${CYAN}$num)${NC} ${OLLAMA_MODELS[$idx]}"
done
echo ""
while true; do
read -r -p "Enter choice (1-${#OLLAMA_MODELS[@]}): " model_choice
if [[ "$model_choice" =~ ^[0-9]+$ ]] && [ "$model_choice" -ge 1 ] && [ "$model_choice" -le ${#OLLAMA_MODELS[@]} ]; then
SELECTED_MODEL="${OLLAMA_MODELS[$((model_choice - 1))]}"
SELECTED_API_BASE="http://localhost:11434"
break
fi
echo -e "${RED}Invalid choice. Please enter 1-${#OLLAMA_MODELS[@]}${NC}"
done
echo ""
echo -e "${GREEN}✓${NC} Using Ollama with model ${DIM}$SELECTED_MODEL${NC}"
echo -e "${YELLOW} ⚠ Note: The framework uses a ~9,500 token system prompt and requires strong tool use.${NC}"
echo -e "${YELLOW} For best results, use models like qwen2.5:72b+ or mistral-large.${NC}"
echo ""
else
echo ""
echo -e "${RED}No Ollama models found.${NC}"
echo -e " Please open another terminal, run ${CYAN}ollama pull llama3${NC} (or another model),"
echo -e " and then run this quickstart again."
echo ""
exit 1
fi
;;
"$SKIP_CHOICE")
echo ""
echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
@@ -1584,6 +1666,10 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
elif [ "$SELECTED_PROVIDER_ID" = "ollama" ]; then
# Pass api_base explicitly — LiteLLM requires this to route ollama/* models
# to the local Ollama server instead of trying to reach a remote endpoint.
save_configuration "ollama" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "http://localhost:11434" > /dev/null || SAVE_OK=false
else
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" > /dev/null || SAVE_OK=false
fi
@@ -1859,6 +1945,9 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
echo -e " ${GREEN}✓${NC} OpenRouter API Key → ${DIM}$SELECTED_MODEL${NC}"
echo -e " ${DIM}API: openrouter.ai/api/v1 (OpenAI-compatible)${NC}"
elif [ "$SELECTED_PROVIDER_ID" = "ollama" ]; then
echo -e " ${GREEN}✓${NC} Local (Ollama) → ${DIM}$SELECTED_MODEL${NC}"
echo -e " ${DIM}No API key required (runs locally via http://localhost:11434)${NC}"
else
echo -e " ${CYAN}$SELECTED_PROVIDER_ID${NC} → ${DIM}$SELECTED_MODEL${NC}"
fi
+16 -1
@@ -318,7 +318,22 @@ PROVIDERS = {
key, "https://api.cerebras.ai/v1/models", "Cerebras"
),
"openrouter": lambda key, **kw: check_openrouter(key, **kw),
"minimax": lambda key, **kw: check_minimax(key),
"deepseek": lambda key, **_: check_openai_compatible(
key, "https://api.deepseek.com/v1/models", "DeepSeek"
),
"together": lambda key, **_: check_openai_compatible(
key, "https://api.together.xyz/v1/models", "Together AI"
),
"mistral": lambda key, **_: check_openai_compatible(
key, "https://api.mistral.ai/v1/models", "Mistral"
),
"xai": lambda key, **_: check_openai_compatible(
key, "https://api.x.ai/v1/models", "xAI"
),
"perplexity": lambda key, **_: check_openai_compatible(
key, "https://api.perplexity.ai/v1/models", "Perplexity"
),
"minimax": lambda key, **_: check_minimax(key),
# Kimi For Coding uses an Anthropic-compatible endpoint; check via /v1/messages
# with empty messages (same as check_anthropic, triggers 400 not 401).
"kimi": lambda key, **kw: check_anthropic_compatible(
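The new provider entries all funnel into a shared probe against each provider's OpenAI-compatible `/models` endpoint. A minimal sketch of such a helper, assuming a `(key, url, provider)` call shape and an `(ok, detail)` return — the repo's actual `check_openai_compatible` may differ:

```python
import urllib.error
import urllib.request


def check_openai_compatible(key: str, models_url: str, provider: str) -> tuple[bool, str]:
    """Probe an OpenAI-compatible /models endpoint; HTTP 200 means the key is accepted."""
    req = urllib.request.Request(models_url, headers={"Authorization": f"Bearer {key}"})
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            return resp.status == 200, f"{provider}: HTTP {resp.status}"
    except urllib.error.HTTPError as exc:
        # 401/403 mean the endpoint is reachable but the key was rejected
        return False, f"{provider}: HTTP {exc.code}"
    except urllib.error.URLError as exc:
        return False, f"{provider}: {exc.reason}"
```

Because every endpoint in the dict above follows the same `/v1/models` + Bearer-token convention, adding DeepSeek, Together, Mistral, xAI, or Perplexity only requires a new URL, not new logic.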
@@ -95,6 +95,7 @@ from .kafka import KAFKA_CREDENTIALS
from .langfuse import LANGFUSE_CREDENTIALS
from .linear import LINEAR_CREDENTIALS
from .lusha import LUSHA_CREDENTIALS
from .mattermost import MATTERMOST_CREDENTIALS
from .microsoft_graph import MICROSOFT_GRAPH_CREDENTIALS
from .mongodb import MONGODB_CREDENTIALS
from .n8n import N8N_CREDENTIALS
@@ -179,6 +180,7 @@ CREDENTIAL_SPECS = {
**LANGFUSE_CREDENTIALS,
**LINEAR_CREDENTIALS,
**LUSHA_CREDENTIALS,
**MATTERMOST_CREDENTIALS,
**MICROSOFT_GRAPH_CREDENTIALS,
**MONGODB_CREDENTIALS,
**N8N_CREDENTIALS,
@@ -271,6 +273,7 @@ __all__ = [
"LANGFUSE_CREDENTIALS",
"LINEAR_CREDENTIALS",
"LUSHA_CREDENTIALS",
"MATTERMOST_CREDENTIALS",
"MICROSOFT_GRAPH_CREDENTIALS",
"MONGODB_CREDENTIALS",
"N8N_CREDENTIALS",
@@ -0,0 +1,65 @@
"""
Mattermost tool credentials.

Contains credentials for Mattermost server integration.
"""

from .base import CredentialSpec

MATTERMOST_CREDENTIALS = {
    "mattermost": CredentialSpec(
        env_var="MATTERMOST_ACCESS_TOKEN",
        tools=[
            "mattermost_list_teams",
            "mattermost_list_channels",
            "mattermost_get_channel",
            "mattermost_send_message",
            "mattermost_get_posts",
            "mattermost_create_reaction",
            "mattermost_delete_post",
        ],
        required=True,
        startup_required=False,
        help_url="https://developers.mattermost.com/integrate/reference/personal-access-token/",
        description="Mattermost Personal Access Token",
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""To get a Mattermost Personal Access Token:
1. Log in to your Mattermost server
2. Go to Profile > Security > Personal Access Tokens
3. Click "Create Token"
4. Give it a description and click "Save"
5. Copy the token (it won't be shown again)
Note: Personal access tokens must be enabled by your System Admin.
Also set MATTERMOST_URL to your server URL (e.g. https://mattermost.example.com)""",
        health_check_endpoint=None,
        health_check_method="GET",
        credential_id="mattermost",
        credential_key="access_token",
    ),
    "mattermost_url": CredentialSpec(
        env_var="MATTERMOST_URL",
        tools=[
            "mattermost_list_teams",
            "mattermost_list_channels",
            "mattermost_get_channel",
            "mattermost_send_message",
            "mattermost_get_posts",
            "mattermost_create_reaction",
            "mattermost_delete_post",
        ],
        required=True,
        startup_required=False,
        help_url="https://developers.mattermost.com/integrate/reference/personal-access-token/",
        description="Mattermost Server URL (e.g. https://mattermost.example.com)",
        aden_supported=False,
        direct_api_key_supported=True,
        api_key_instructions="""Set this to your Mattermost server URL, e.g. https://mattermost.example.com
Do not include /api/v4; it will be added automatically.""",
        health_check_endpoint=None,
        health_check_method="GET",
        credential_id="mattermost_url",
        credential_key="url",
    ),
}
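With both env vars set, a tool call reduces to one authenticated request against Mattermost's REST API. A sketch of building the request behind a send-message tool — `build_send_message_request` is a hypothetical helper, though `POST /api/v4/posts` with a Bearer token is Mattermost's standard create-post endpoint:

```python
import json
import os
import urllib.request


def build_send_message_request(channel_id: str, message: str) -> urllib.request.Request:
    """Build the authenticated request for Mattermost's POST /api/v4/posts."""
    # MATTERMOST_URL must not include /api/v4 (per the spec above); strip a trailing slash
    base = os.environ["MATTERMOST_URL"].rstrip("/")
    token = os.environ["MATTERMOST_ACCESS_TOKEN"]
    payload = json.dumps({"channel_id": channel_id, "message": message}).encode()
    return urllib.request.Request(
        f"{base}/api/v4/posts",
        data=payload,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
```

Sending is then just `urllib.request.urlopen(req)`; the other six tools follow the same pattern with different v4 paths.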
@@ -1,13 +1,15 @@
"""
Shell configuration utilities for persisting environment variables.
Supports both bash and zsh, detecting the user's default shell.
Supports bash and zsh with platform-aware fallbacks for login-shell config
files such as ``.bash_profile``, ``.zshenv``, and ``.profile``.
Used primarily for persisting ADEN_API_KEY across sessions.
"""
from __future__ import annotations
import os
import platform
import re
from pathlib import Path
from typing import Literal
@@ -34,9 +36,9 @@ def detect_shell() -> ShellType:
else:
# Try to detect from config file existence
home = Path.home()
if (home / ".zshrc").exists():
if (home / ".zshrc").exists() or (home / ".zshenv").exists():
return "zsh"
elif (home / ".bashrc").exists():
elif (home / ".bashrc").exists() or (home / ".bash_profile").exists():
return "bash"
return "unknown"
@@ -55,14 +57,12 @@ def get_shell_config_path(shell_type: ShellType | None = None) -> Path:
shell_type = detect_shell()
home = Path.home()
candidates = _get_shell_config_candidates(home, shell_type)
if shell_type == "zsh":
return home / ".zshrc"
elif shell_type == "bash":
return home / ".bashrc"
else:
# Default to .bashrc for unknown shells
return home / ".bashrc"
for candidate in candidates:
if candidate.exists():
return candidate
return candidates[0]
def check_env_var_in_shell_config(
@@ -79,29 +79,47 @@ def check_env_var_in_shell_config(
Returns:
Tuple of (exists, current_value or None)
"""
config_path = get_shell_config_path(shell_type)
if shell_type is None:
shell_type = detect_shell()
if not config_path.exists():
return False, None
for config_path in _get_shell_config_candidates(Path.home(), shell_type):
if not config_path.exists():
continue
content = config_path.read_text(encoding="utf-8")
content = config_path.read_text(encoding="utf-8")
# Look for export ENV_VAR=value or export ENV_VAR="value"
pattern = rf"^export\s+{re.escape(env_var)}=(.+)$"
match = re.search(pattern, content, re.MULTILINE)
# Look for export ENV_VAR=value or export ENV_VAR="value"
pattern = rf"^export\s+{re.escape(env_var)}=(.+)$"
match = re.search(pattern, content, re.MULTILINE)
if match:
value = match.group(1).strip()
# Remove surrounding quotes if present
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
value = value[1:-1]
return True, value
if match:
value = match.group(1).strip()
# Remove surrounding quotes if present
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
value = value[1:-1]
return True, value
return False, None
def _get_shell_config_candidates(home: Path, shell_type: ShellType) -> list[Path]:
"""Return candidate config files in lookup order for the detected shell."""
if shell_type == "zsh":
return [home / ".zshrc", home / ".zshenv"]
if shell_type == "bash":
# Git Bash commonly launches login shells on Windows, so prefer
# ``.bash_profile`` there for writes, but keep ``.bashrc`` in the
# lookup list so older setups continue to work.
if platform.system() == "Windows":
return [home / ".bash_profile", home / ".bashrc", home / ".profile"]
return [home / ".bashrc", home / ".bash_profile", home / ".profile"]
return [home / ".profile", home / ".bashrc"]
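The lookup-order change above can be exercised in isolation. A minimal sketch of the same candidate ordering and first-existing-file selection (standalone helper names, not the real module):

```python
from pathlib import Path

def candidates(home: Path, shell: str, system: str) -> list[Path]:
    # Mirrors the lookup order from _get_shell_config_candidates above.
    if shell == "zsh":
        return [home / ".zshrc", home / ".zshenv"]
    if shell == "bash":
        # Git Bash on Windows favors login-shell config, so .bash_profile first.
        if system == "Windows":
            return [home / ".bash_profile", home / ".bashrc", home / ".profile"]
        return [home / ".bashrc", home / ".bash_profile", home / ".profile"]
    return [home / ".profile", home / ".bashrc"]

def pick(home: Path, shell: str, system: str) -> Path:
    # First existing candidate wins; otherwise write to the preferred file.
    cands = candidates(home, shell, system)
    for c in cands:
        if c.exists():
            return c
    return cands[0]
```

This keeps detection and selection separable, which is what the new tests at the bottom of this compare rely on.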
def add_env_var_to_shell_config(
env_var: str,
value: str,
@@ -88,6 +88,7 @@ from .kafka_tool import register_tools as register_kafka
from .langfuse_tool import register_tools as register_langfuse
from .linear_tool import register_tools as register_linear
from .lusha_tool import register_tools as register_lusha
from .mattermost_tool import register_tools as register_mattermost
from .microsoft_graph_tool import register_tools as register_microsoft_graph
from .mongodb_tool import register_tools as register_mongodb
from .n8n_tool import register_tools as register_n8n
@@ -266,6 +267,7 @@ def _register_unverified(
register_langfuse(mcp, credentials=credentials)
register_linear(mcp, credentials=credentials)
register_lusha(mcp, credentials=credentials)
register_mattermost(mcp, credentials=credentials)
register_microsoft_graph(mcp, credentials=credentials)
register_mongodb(mcp, credentials=credentials)
register_n8n(mcp, credentials=credentials)
@@ -2,6 +2,7 @@
import csv
import os
import re
from fastmcp import FastMCP
@@ -330,39 +331,39 @@ def register_tools(mcp: FastMCP) -> None:
if not query or not query.strip():
return {"error": "query cannot be empty"}
# Security: only allow SELECT statements
query_upper = query.strip().upper()
if not query_upper.startswith("SELECT"):
# Security: allow SELECT/WITH only
query_upper = query.lstrip().upper()
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
return {"error": "Only SELECT queries are allowed for security reasons"}
# Disallowed keywords for security
disallowed = [
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"CREATE",
"ALTER",
"TRUNCATE",
"EXEC",
"EXECUTE",
]
for keyword in disallowed:
if keyword in query_upper:
return {"error": f"'{keyword}' is not allowed in queries"}
# Disallowed keywords for security (word-boundary match to avoid
# false positives on column names like created_at, updated_at, etc.)
_WRITE_PATTERN = re.compile(
r"\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE)\b",
re.IGNORECASE,
)
match = _WRITE_PATTERN.search(query)
if match:
return {"error": f"'{match.group().upper()}' is not allowed in queries"}
# Block obvious multi-statement / injection attempts
q_lower = query.lower()
for token in [";", "--", "/*", "*/"]:
if token in q_lower:
return {"error": "Multiple statements or comments are not allowed"}
# Execute query using in-memory DuckDB
con = duckdb.connect(":memory:")
try:
# Load CSV as 'data' table
con.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{secure_path}')")
# SAFE: parameter binding (no string interpolation)
con.execute(
"CREATE TABLE data AS SELECT * FROM read_csv_auto(?)",
[str(secure_path)],
)
# Execute user query
result = con.execute(query)
columns = [desc[0] for desc in result.description]
rows = result.fetchall()
# Convert to list of dicts
rows_as_dicts = [dict(zip(columns, row, strict=False)) for row in rows]
return {
@@ -374,12 +375,12 @@ def register_tools(mcp: FastMCP) -> None:
"rows": rows_as_dicts,
"row_count": len(rows_as_dicts),
}
finally:
con.close()
except Exception as e:
error_msg = str(e)
# Make DuckDB errors more readable
if "Catalog Error" in error_msg:
return {"error": f"SQL error: {error_msg}. Remember the table is named 'data'."}
return {"error": f"Query failed: {error_msg}"}
@@ -12,7 +12,6 @@ import os
from typing import TYPE_CHECKING, Literal
import httpx
import resend
from fastmcp import FastMCP
if TYPE_CHECKING:
@@ -35,6 +34,15 @@ def register_tools(
bcc: list[str] | None = None,
) -> dict:
"""Send email using Resend API."""
try:
import resend
except ImportError:
return {
"error": (
"resend not installed. Install with: "
"pip install resend or pip install tools[email]"
)
}
resend.api_key = api_key
try:
payload: dict = {
@@ -0,0 +1,5 @@
"""Mattermost Tool - Send messages and interact with Mattermost servers."""
from .mattermost_tool import register_tools
__all__ = ["register_tools"]
@@ -0,0 +1,447 @@
"""
Mattermost Tool - Send messages and interact with Mattermost servers via Mattermost API.
Supports:
- Personal access tokens (MATTERMOST_ACCESS_TOKEN)
- Self-hosted and cloud Mattermost instances (MATTERMOST_URL)
API Reference: https://api.mattermost.com/
"""
from __future__ import annotations
import os
import time
from typing import TYPE_CHECKING, Any
import httpx
from fastmcp import FastMCP
if TYPE_CHECKING:
from aden_tools.credentials import CredentialStoreAdapter
MAX_MESSAGE_LENGTH = 16383 # Mattermost API limit
MAX_RETRIES = 2 # 3 total attempts on 429
MAX_RETRY_WAIT = 60 # cap wait at 60s
class _MattermostClient:
"""Internal client wrapping Mattermost API calls."""
def __init__(self, access_token: str, base_url: str):
# Strip trailing slash and ensure /api/v4 suffix
base_url = base_url.rstrip("/")
if not base_url.endswith("/api/v4"):
base_url = f"{base_url}/api/v4"
self._base_url = base_url
self._token = access_token
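The URL normalization in the constructor can be sketched as a standalone helper (hypothetical name), showing why inputs with and without the `/api/v4` suffix converge, as the tests for this file verify below:

```python
def normalize_base_url(base_url: str) -> str:
    # Drop any trailing slash, then append /api/v4 exactly once.
    base_url = base_url.rstrip("/")
    if not base_url.endswith("/api/v4"):
        base_url = f"{base_url}/api/v4"
    return base_url
```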
@property
def _headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
def _request_with_retry(
self,
method: str,
url: str,
**kwargs: Any,
) -> dict[str, Any]:
"""Make HTTP request with retry on 429 rate limit."""
request_kwargs = {"headers": self._headers, "timeout": 30.0, **kwargs}
for attempt in range(MAX_RETRIES + 1):
response = httpx.request(method, url, **request_kwargs)
if response.status_code == 429 and attempt < MAX_RETRIES:
try:
wait = min(float(response.headers.get("Retry-After", 1)), MAX_RETRY_WAIT)
except (ValueError, TypeError):
wait = min(2**attempt, MAX_RETRY_WAIT)
time.sleep(wait)
continue
return self._handle_response(response)
return self._handle_response(response)
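The wait computation inside the retry loop can be isolated for clarity. A small sketch (hypothetical helper name) assuming the same cap and exponential fallback as above:

```python
def retry_wait(headers: dict, attempt: int, cap: float = 60.0) -> float:
    # Prefer the server's Retry-After hint; fall back to exponential
    # backoff (2**attempt) when the header is unparseable. Either path
    # is capped at `cap` seconds, matching MAX_RETRY_WAIT above.
    try:
        return min(float(headers.get("Retry-After", 1)), cap)
    except (ValueError, TypeError):
        return min(float(2 ** attempt), cap)
```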
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
"""Handle Mattermost API response format."""
if response.status_code == 204:
return {"success": True}
if response.status_code == 429:
try:
retry_after = float(response.headers.get("Retry-After", 60))
except (ValueError, TypeError):
retry_after = 60
return {
"error": f"Mattermost rate limit exceeded. Retry after {retry_after}s",
"retry_after": retry_after,
}
if response.status_code not in (200, 201):
try:
data = response.json()
message = data.get("message", response.text)
except Exception:
message = response.text
return {"error": f"HTTP {response.status_code}: {message}"}
return response.json()
def get_me(self) -> dict[str, Any]:
"""Get the authenticated user's info (health check)."""
return self._request_with_retry("GET", f"{self._base_url}/users/me")
def list_teams(self) -> dict[str, Any]:
"""List teams the authenticated user belongs to."""
return self._request_with_retry("GET", f"{self._base_url}/users/me/teams")
def list_channels(self, team_id: str, per_page: int = 100) -> dict[str, Any]:
"""List public channels for a team."""
return self._request_with_retry(
"GET",
f"{self._base_url}/teams/{team_id}/channels",
params={"per_page": min(per_page, 200)},
)
def get_channel(self, channel_id: str) -> dict[str, Any]:
"""Get detailed information about a channel."""
return self._request_with_retry("GET", f"{self._base_url}/channels/{channel_id}")
def send_message(
self,
channel_id: str,
message: str,
*,
root_id: str = "",
) -> dict[str, Any]:
"""Create a post in a channel."""
body: dict[str, Any] = {
"channel_id": channel_id,
"message": message,
}
if root_id:
body["root_id"] = root_id
return self._request_with_retry(
"POST",
f"{self._base_url}/posts",
json=body,
)
def get_posts(
self,
channel_id: str,
per_page: int = 60,
page: int = 0,
before: str = "",
after: str = "",
) -> dict[str, Any]:
"""Get posts from a channel."""
params: dict[str, Any] = {
"per_page": min(per_page, 200),
"page": page,
}
if before:
params["before"] = before
if after:
params["after"] = after
return self._request_with_retry(
"GET",
f"{self._base_url}/channels/{channel_id}/posts",
params=params,
)
def create_reaction(
self,
post_id: str,
emoji_name: str,
) -> dict[str, Any]:
"""Add a reaction to a post.
API ref: POST /reactions
"""
# Need user_id for the reaction; fetch from /users/me
me = self.get_me()
if isinstance(me, dict) and "error" in me:
return me
user_id = me.get("id", "")
return self._request_with_retry(
"POST",
f"{self._base_url}/reactions",
json={
"user_id": user_id,
"post_id": post_id,
"emoji_name": emoji_name,
},
)
def delete_post(self, post_id: str) -> dict[str, Any]:
"""Delete a post."""
return self._request_with_retry("DELETE", f"{self._base_url}/posts/{post_id}")
def register_tools(
mcp: FastMCP,
credentials: CredentialStoreAdapter | None = None,
) -> None:
"""Register Mattermost tools with the MCP server."""
def _get_token(account: str = "") -> str | None:
"""Get Mattermost access token from credential manager or environment."""
if credentials is not None:
if account:
return credentials.get_by_alias("mattermost", account)
token = credentials.get("mattermost")
if token is not None and not isinstance(token, str):
raise TypeError(
"Expected string from credentials.get('mattermost'), "
f"got {type(token).__name__}"
)
return token
return os.getenv("MATTERMOST_ACCESS_TOKEN")
def _get_url() -> str | None:
"""Get Mattermost server URL from credential manager or environment."""
if credentials is not None:
url = credentials.get("mattermost_url")
if url is not None and not isinstance(url, str):
raise TypeError(
"Expected string from credentials.get('mattermost_url'), "
f"got {type(url).__name__}"
)
if url:
return url
return os.getenv("MATTERMOST_URL")
def _get_client(account: str = "") -> _MattermostClient | dict[str, str]:
"""Get a Mattermost client, or return an error dict if no credentials."""
token = _get_token(account)
if not token:
return {
"error": "Mattermost credentials not configured",
"help": (
"Set MATTERMOST_ACCESS_TOKEN and MATTERMOST_URL environment variables "
"or configure via credential store"
),
}
url = _get_url()
if not url:
return {
"error": "Mattermost server URL not configured",
"help": (
"Set MATTERMOST_URL environment variable (e.g. https://mattermost.example.com) "
"or configure via credential store"
),
}
return _MattermostClient(token, url)
@mcp.tool()
def mattermost_list_teams(account: str = "") -> dict:
"""
List Mattermost teams the authenticated user belongs to.
Returns team IDs and names. Use team IDs with mattermost_list_channels.
Returns:
Dict with list of teams or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.list_teams()
if isinstance(result, dict) and "error" in result:
return result
return {"teams": result, "success": True}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_list_channels(team_id: str, per_page: int = 100, account: str = "") -> dict:
"""
List public channels for a Mattermost team.
Args:
team_id: Team ID. Use mattermost_list_teams to find team IDs.
per_page: Max channels to return (1-200, default 100).
Returns:
Dict with list of channels or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.list_channels(team_id, per_page=per_page)
if isinstance(result, dict) and "error" in result:
return result
return {"channels": result, "success": True}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_get_channel(channel_id: str, account: str = "") -> dict:
"""
Get detailed information about a Mattermost channel.
Returns channel metadata including name, display name, header, purpose,
and type.
Args:
channel_id: Channel ID
Returns:
Dict with channel details or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.get_channel(channel_id)
if isinstance(result, dict) and "error" in result:
return result
return {"channel": result, "success": True}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_send_message(
channel_id: str,
message: str,
root_id: str = "",
account: str = "",
) -> dict:
"""
Send a message (post) to a Mattermost channel.
Args:
channel_id: Channel ID to post in
message: Message text (max 16383 characters). Supports Markdown.
root_id: Optional post ID to reply to (creates a thread)
Returns:
Dict with post details or error
"""
if len(message) > MAX_MESSAGE_LENGTH:
return {
"error": f"Message exceeds {MAX_MESSAGE_LENGTH} character limit",
"max_length": MAX_MESSAGE_LENGTH,
"provided": len(message),
}
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.send_message(channel_id, message, root_id=root_id)
if isinstance(result, dict) and "error" in result:
return result
return {"success": True, "post": result}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_get_posts(
channel_id: str,
per_page: int = 60,
page: int = 0,
before: str = "",
after: str = "",
account: str = "",
) -> dict:
"""
Get posts from a Mattermost channel.
Args:
channel_id: Channel ID
per_page: Max posts to return (1-200, default 60)
page: Page number for pagination (default 0)
before: Post ID to get posts before (for pagination)
after: Post ID to get posts after (for pagination)
Returns:
Dict with posts or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.get_posts(
channel_id,
per_page=per_page,
page=page,
before=before,
after=after,
)
if isinstance(result, dict) and "error" in result:
return result
return {"posts": result, "success": True}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_create_reaction(
post_id: str,
emoji_name: str,
account: str = "",
) -> dict:
"""
Add a reaction to a Mattermost post.
Args:
post_id: ID of the post to react to
emoji_name: Emoji name without colons (e.g. "thumbsup", "heart")
Returns:
Dict with success status or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.create_reaction(post_id, emoji_name)
if isinstance(result, dict) and "error" in result:
return result
return {"success": True}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@mcp.tool()
def mattermost_delete_post(
post_id: str,
account: str = "",
) -> dict:
"""
Delete a post from Mattermost.
Requires appropriate permissions (post author or admin).
Args:
post_id: ID of the post to delete
Returns:
Dict with success status or error
"""
client = _get_client(account)
if isinstance(client, dict):
return client
try:
result = client.delete_post(post_id)
if isinstance(result, dict) and "error" in result:
return result
return {"success": True, "deleted_post_id": post_id}
except httpx.TimeoutException:
return {"error": "Request timed out"}
except httpx.RequestError as e:
return {"error": f"Network error: {e}"}
@@ -8,25 +8,13 @@ from __future__ import annotations
import base64
import os
from typing import Any
from typing import TYPE_CHECKING, Any
import httpx
from fastmcp import FastMCP
def _get_config() -> tuple[str, dict] | dict:
"""Return (base_url, headers) or error dict."""
base_url = os.getenv("SAP_BASE_URL", "").rstrip("/")
username = os.getenv("SAP_USERNAME", "")
password = os.getenv("SAP_PASSWORD", "")
if not base_url or not username or not password:
return {
"error": "SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD are required",
"help": "Set SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD environment variables",
}
creds = base64.b64encode(f"{username}:{password}".encode()).decode()
headers = {"Authorization": f"Basic {creds}", "Accept": "application/json"}
return base_url, headers
if TYPE_CHECKING:
from aden_tools.credentials import CredentialStoreAdapter
def _get(url: str, headers: dict, params: dict | None = None) -> dict:
@@ -45,9 +33,43 @@ def _odata_list(data: dict) -> tuple[list, int | None]:
return results, count
def register_tools(mcp: FastMCP, credentials: Any = None) -> None:
def register_tools(
mcp: FastMCP,
credentials: CredentialStoreAdapter | None = None,
) -> None:
"""Register SAP S/4HANA tools."""
def _get_config() -> tuple[str, dict] | dict[str, str]:
"""Return (base_url, headers) or error dict."""
base_url = username = password = None
if credentials is not None:
try:
base_url = credentials.get("sap_base_url")
username = credentials.get("sap_username")
password = credentials.get("sap_password")
except KeyError:
pass
base_url = base_url or os.getenv("SAP_BASE_URL")
username = username or os.getenv("SAP_USERNAME")
password = password or os.getenv("SAP_PASSWORD")
if not base_url or not username or not password:
return {
"error": "SAP credentials not configured",
"help": (
"Set SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD "
"environment variables or configure via credential store"
),
}
base_url = base_url.rstrip("/")
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
headers = {
"Authorization": f"Basic {encoded}",
"Accept": "application/json",
}
return base_url, headers
@mcp.tool()
def sap_list_purchase_orders(
top: int = 50,
@@ -4,10 +4,13 @@ Web Scrape Tool - Extract content from web pages.
Uses Playwright with stealth for headless browser scraping,
enabling JavaScript-rendered content and bot detection evasion.
Uses BeautifulSoup for HTML parsing and content extraction.
Validates URLs against internal network ranges to prevent SSRF attacks.
"""
from __future__ import annotations
import ipaddress
import socket
from typing import Any
from urllib.parse import urljoin, urlparse
from urllib.robotparser import RobotFileParser
@@ -29,6 +32,49 @@ BROWSER_USER_AGENT = (
)
def _is_internal_address(raw_ip: str) -> bool:
"""Check whether an IP address targets non-public infrastructure."""
ip_str = raw_ip.split("%")[0] if "%" in raw_ip else raw_ip
try:
addr = ipaddress.ip_address(ip_str)
except ValueError:
return True # Unparseable — fail closed
return not addr.is_global or addr.is_multicast
def _check_url_target(url: str) -> str | None:
"""Resolve a URL's hostname and reject it if any address is non-public.
Returns an error message if blocked, None if safe.
"""
hostname = urlparse(url).hostname
if not hostname:
return "Invalid URL: missing hostname"
# Fast-path for raw IP literals
try:
ipaddress.ip_address(hostname)
if _is_internal_address(hostname):
return f"Blocked: direct request to internal address ({hostname})"
except ValueError:
pass # Not an IP literal, resolve below
try:
results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
return f"DNS resolution failed for host: {hostname}"
if not results:
return f"No DNS records found for host: {hostname}"
for entry in results:
resolved_ip = str(entry[4][0])
if _is_internal_address(resolved_ip):
return f"Blocked: {hostname} resolves to internal address"
return None
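The address classification above relies entirely on the stdlib `ipaddress` module. A minimal standalone sketch of the same check, for reference:

```python
import ipaddress

def is_internal(raw_ip: str) -> bool:
    # Strip an IPv6 zone index (e.g. fe80::1%eth0) before parsing.
    ip_str = raw_ip.split("%")[0]
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # fail closed on unparseable input
    # is_global is False for private, loopback, link-local, and reserved
    # ranges, which covers the cloud metadata endpoint 169.254.169.254.
    # Globally routable multicast is blocked explicitly.
    return not addr.is_global or addr.is_multicast
```

Blocking on every resolved address (not just the first) is what defeats DNS-rebinding-style bypasses in `_check_url_target`.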
def register_tools(mcp: FastMCP) -> None:
"""Register web scrape tools with the MCP server."""
@@ -65,6 +111,12 @@ def register_tools(mcp: FastMCP) -> None:
# Validate max_length
max_length = max(1000, min(max_length, 500000))
# SSRF check: validate URL before making any request (must run
# before robots.txt fetch, which also makes a network request)
block_reason = _check_url_target(url)
if block_reason is not None:
return {"error": block_reason, "blocked_by_ssrf_protection": True, "url": url}
# Check robots.txt before launching browser
if respect_robots_txt:
try:
@@ -102,12 +154,44 @@ def register_tools(mcp: FastMCP) -> None:
page = await context.new_page()
await Stealth().apply_stealth_async(page)
# Intercept navigation requests to block SSRF via redirects.
# Only check "document" requests (navigations), not
# sub-resources (CSS/JS/images) to avoid false positives
# and unnecessary DNS lookups.
ssrf_blocked: dict[str, Any] | None = None
async def _ssrf_route_handler(route):
nonlocal ssrf_blocked
req_url = route.request.url
# Skip non-network schemes (data:, blob:, etc.)
if urlparse(req_url).scheme not in {"http", "https"}:
await route.continue_()
return
block = _check_url_target(req_url)
if block is not None:
ssrf_blocked = {
"error": block,
"blocked_by_ssrf_protection": True,
"url": req_url,
}
await route.abort("blockedbyclient")
else:
await route.continue_()
await page.route("**/*", _ssrf_route_handler)
response = await page.goto(
url,
wait_until="domcontentloaded",
timeout=60000,
)
# Check if a redirect was blocked by SSRF protection
if ssrf_blocked is not None:
return ssrf_blocked
# Validate response before waiting for JS render
if response is None:
return {"error": "Navigation failed: no response received"}
@@ -0,0 +1,90 @@
"""Tests for shell config path selection and env-var lookups."""
from pathlib import Path
from aden_tools.credentials import shell_config
def _mock_home(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setattr(shell_config.Path, "home", staticmethod(lambda: tmp_path))
def test_get_shell_config_path_prefers_existing_bash_profile(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/usr/bin/bash")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
(tmp_path / ".bashrc").write_text("# bashrc\n", encoding="utf-8")
(tmp_path / ".bash_profile").write_text("# bash profile\n", encoding="utf-8")
assert shell_config.get_shell_config_path() == tmp_path / ".bash_profile"
def test_get_shell_config_path_prefers_bashrc_for_non_windows_bash(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/usr/bin/bash")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Linux")
(tmp_path / ".bashrc").write_text("# bashrc\n", encoding="utf-8")
(tmp_path / ".bash_profile").write_text("# bash profile\n", encoding="utf-8")
assert shell_config.get_shell_config_path() == tmp_path / ".bashrc"
def test_check_env_var_in_shell_config_reads_bash_profile(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/usr/bin/bash")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
(tmp_path / ".bash_profile").write_text(
'export HIVE_API_KEY="hive-key-123"\n',
encoding="utf-8",
)
assert shell_config.check_env_var_in_shell_config("HIVE_API_KEY") == (
True,
"hive-key-123",
)
def test_check_env_var_in_shell_config_falls_back_to_bashrc_on_windows(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/usr/bin/bash")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
(tmp_path / ".bash_profile").write_text("# no key here\n", encoding="utf-8")
(tmp_path / ".bashrc").write_text(
'export HIVE_API_KEY="hive-key-from-bashrc"\n',
encoding="utf-8",
)
assert shell_config.check_env_var_in_shell_config("HIVE_API_KEY") == (
True,
"hive-key-from-bashrc",
)
def test_check_env_var_in_shell_config_reads_zshenv_when_zshrc_missing(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/bin/zsh")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Darwin")
(tmp_path / ".zshenv").write_text(
"export OPENROUTER_API_KEY='or-key-123'\n",
encoding="utf-8",
)
assert shell_config.check_env_var_in_shell_config("OPENROUTER_API_KEY") == (
True,
"or-key-123",
)
def test_get_shell_config_path_falls_back_to_profile_for_unknown_shell(monkeypatch, tmp_path):
_mock_home(monkeypatch, tmp_path)
monkeypatch.setenv("SHELL", "/usr/bin/fish")
monkeypatch.setattr(shell_config.platform, "system", lambda: "Linux")
(tmp_path / ".profile").write_text("# profile\n", encoding="utf-8")
assert shell_config.get_shell_config_path() == tmp_path / ".profile"
@@ -732,6 +732,100 @@ class TestCsvSql:
assert "id" in result["columns"]
assert "name" in result["columns"]
def test_path_with_single_quote(self, csv_tools, session_dir, tmp_path):
"""Regression: CSV paths containing single quotes should work (parameter binding)."""
csv_file = session_dir / "O'Reilly.csv"
csv_file.write_text("name,age\nAlice,21\nBob,22\n", encoding="utf-8")
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path="O'Reilly.csv",
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query="SELECT * FROM data",
)
assert "error" not in result, result
assert result["success"] is True
assert result["row_count"] == 2
names = [row["name"] for row in result["rows"]]
assert "Alice" in names
assert "Bob" in names
# --- NEW: security regression tests required by Issue #1256 ---
def test_reject_non_select(self, csv_tools, products_csv, tmp_path):
"""Reject any non-SELECT / non-WITH query."""
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path=products_csv.name,
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query="DROP TABLE data",
)
assert "error" in result
def test_reject_multi_statement(self, csv_tools, products_csv, tmp_path):
"""Reject multi-statement queries with semicolons."""
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path=products_csv.name,
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query="SELECT * FROM data; DROP TABLE data",
)
assert "error" in result
def test_reject_sql_comment_dash(self, csv_tools, products_csv, tmp_path):
"""Reject queries with SQL line comments."""
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path=products_csv.name,
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query="SELECT * FROM data -- WHERE id = 1",
)
assert "error" in result
def test_with_cte_allowed(self, csv_tools, products_csv, tmp_path):
"""Allow valid WITH (CTE) queries."""
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path=products_csv.name,
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query=(
"WITH electronics AS (SELECT * FROM data"
" WHERE category = 'Electronics')"
" SELECT * FROM electronics"
),
)
assert result["success"] is True
def test_keyword_in_column_name_allowed(self, csv_tools, session_dir, tmp_path):
"""Column names like created_at should not trigger keyword blocking."""
csv_file = session_dir / "timestamps.csv"
csv_file.write_text(
"created_at,updated_at,value\n2024-01-01,2024-01-02,100\n",
encoding="utf-8",
)
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
result = csv_tools["csv_sql"](
path="timestamps.csv",
workspace_id=TEST_WORKSPACE_ID,
agent_id=TEST_AGENT_ID,
session_id=TEST_SESSION_ID,
query="SELECT created_at, updated_at FROM data",
)
assert "error" not in result, result
assert result["success"] is True
def test_where_clause(self, csv_tools, products_csv, tmp_path):
"""Filter with WHERE clause."""
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
@@ -0,0 +1,612 @@
"""
Tests for Mattermost tool.
Covers:
- _MattermostClient methods (list_teams, list_channels, send_message, get_posts, etc.)
- Error handling (401, 403, 404, 429, timeout)
- Credential retrieval (CredentialStoreAdapter vs env var)
- All MCP tool functions
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from aden_tools.tools.mattermost_tool.mattermost_tool import (
MAX_MESSAGE_LENGTH,
MAX_RETRIES,
_MattermostClient,
register_tools,
)
# --- _MattermostClient tests ---
class TestMattermostClient:
def setup_method(self):
self.client = _MattermostClient("test-access-token", "https://mattermost.example.com")
def test_headers(self):
headers = self.client._headers
assert headers["Content-Type"] == "application/json"
assert headers["Authorization"] == "Bearer test-access-token"
def test_base_url_strips_trailing_slash(self):
client = _MattermostClient("tok", "https://mm.example.com/")
assert client._base_url == "https://mm.example.com/api/v4"
def test_base_url_preserves_api_v4(self):
client = _MattermostClient("tok", "https://mm.example.com/api/v4")
assert client._base_url == "https://mm.example.com/api/v4"
def test_base_url_appends_api_v4(self):
client = _MattermostClient("tok", "https://mm.example.com")
assert client._base_url == "https://mm.example.com/api/v4"
def test_handle_response_success(self):
response = MagicMock()
response.status_code = 200
response.json.return_value = {"id": "abc123", "username": "testbot"}
assert self.client._handle_response(response) == {
"id": "abc123",
"username": "testbot",
}
def test_handle_response_201(self):
response = MagicMock()
response.status_code = 201
response.json.return_value = {"id": "post123", "message": "hello"}
result = self.client._handle_response(response)
assert result == {"id": "post123", "message": "hello"}
def test_handle_response_204(self):
response = MagicMock()
response.status_code = 204
result = self.client._handle_response(response)
assert result == {"success": True}
def test_handle_response_rate_limit_429(self):
response = MagicMock()
response.status_code = 429
response.headers = {"Retry-After": "2.5"}
result = self.client._handle_response(response)
assert "error" in result
assert "rate limit" in result["error"].lower()
assert result["retry_after"] == 2.5
@pytest.mark.parametrize("status_code", [401, 403, 404, 500])
def test_handle_response_errors(self, status_code):
response = MagicMock()
response.status_code = status_code
response.json.return_value = {"message": "Test error"}
response.text = "Test error"
result = self.client._handle_response(response)
assert "error" in result
assert str(status_code) in result["error"]
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_list_teams(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value=[
{"id": "t1", "name": "test-team", "display_name": "Test Team"},
{"id": "t2", "name": "dev-team", "display_name": "Dev Team"},
]
),
)
result = self.client.list_teams()
mock_request.assert_called_once()
assert mock_request.call_args[0][0] == "GET"
assert "users/me/teams" in mock_request.call_args[0][1]
assert len(result) == 2
assert result[0]["display_name"] == "Test Team"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_list_channels(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value=[
{"id": "c1", "name": "town-square", "type": "O"},
{"id": "c2", "name": "off-topic", "type": "O"},
]
),
)
result = self.client.list_channels("t1")
mock_request.assert_called_once()
assert "teams/t1/channels" in mock_request.call_args[0][1]
assert len(result) == 2
assert result[0]["name"] == "town-square"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_send_message(self, mock_request):
mock_request.return_value = MagicMock(
status_code=201,
json=MagicMock(
return_value={
"id": "p123",
"channel_id": "c1",
"message": "Hello world",
}
),
)
result = self.client.send_message("c1", "Hello world")
mock_request.assert_called_once()
assert mock_request.call_args[0][0] == "POST"
assert "posts" in mock_request.call_args[0][1]
assert result["message"] == "Hello world"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_send_message_with_thread(self, mock_request):
mock_request.return_value = MagicMock(
status_code=201,
json=MagicMock(
return_value={
"id": "p124",
"channel_id": "c1",
"message": "Reply",
"root_id": "p123",
}
),
)
result = self.client.send_message("c1", "Reply", root_id="p123")
call_kwargs = mock_request.call_args[1]
assert call_kwargs["json"]["root_id"] == "p123"
assert result["root_id"] == "p123"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_get_posts(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value={
"order": ["p1", "p2"],
"posts": {
"p1": {"id": "p1", "message": "First"},
"p2": {"id": "p2", "message": "Second"},
},
}
),
)
result = self.client.get_posts("c1", per_page=10)
mock_request.assert_called_once()
assert mock_request.call_args[1]["params"]["per_page"] == 10
assert "order" in result
assert len(result["order"]) == 2
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_get_channel(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value={
"id": "c1",
"name": "town-square",
"display_name": "Town Square",
"type": "O",
}
),
)
result = self.client.get_channel("c1")
assert result["name"] == "town-square"
assert result["type"] == "O"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_delete_post(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200, json=MagicMock(return_value={"status": "ok"})
)
self.client.delete_post("p123")
assert mock_request.call_args[0][0] == "DELETE"
assert "posts/p123" in mock_request.call_args[0][1]
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_create_reaction(self, mock_request):
# First call returns user info, second creates the reaction
mock_request.side_effect = [
MagicMock(
status_code=200,
json=MagicMock(return_value={"id": "user123", "username": "testbot"}),
),
MagicMock(
status_code=200,
json=MagicMock(
return_value={
"user_id": "user123",
"post_id": "p123",
"emoji_name": "thumbsup",
}
),
),
]
result = self.client.create_reaction("p123", "thumbsup")
assert result["emoji_name"] == "thumbsup"
# Second call should be the reaction POST
assert mock_request.call_args_list[1][1]["json"]["emoji_name"] == "thumbsup"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.time.sleep")
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_retry_on_429_then_success(self, mock_request, mock_sleep):
mock_request.side_effect = [
MagicMock(
status_code=429,
headers={"Retry-After": "0.01"},
text="{}",
),
MagicMock(
status_code=200,
json=MagicMock(return_value=[{"id": "t1", "name": "team"}]),
),
]
result = self.client.list_teams()
assert len(result) == 1
assert result[0]["name"] == "team"
assert mock_request.call_count == 2
mock_sleep.assert_called_once_with(0.01)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.time.sleep")
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_retry_exhausted_returns_error(self, mock_request, mock_sleep):
mock_request.return_value = MagicMock(
status_code=429,
headers={"Retry-After": "0.01"},
text="{}",
)
result = self.client.list_teams()
assert "error" in result
assert "rate limit" in result["error"].lower()
assert mock_request.call_count == MAX_RETRIES + 1
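The two retry tests above pin down the client's rate-limit behaviour: sleep for the `Retry-After` duration on a 429, retry up to `MAX_RETRIES` times, and return a structured error once retries are exhausted (hence the `call_count == MAX_RETRIES + 1` assertions). A minimal sketch of that loop, assuming a generic `send` callable rather than the actual `_MattermostClient` internals:

```python
import time

MAX_RETRIES = 3  # illustrative value; the tests only rely on the constant existing


def request_with_retry(send):
    """Call `send()` and retry on HTTP 429, honouring the Retry-After header.

    Makes at most MAX_RETRIES + 1 attempts in total, matching the
    `mock_request.call_count == MAX_RETRIES + 1` assertions in the tests.
    """
    last = None
    for attempt in range(MAX_RETRIES + 1):
        last = send()
        if last.status_code != 429:
            return last
        if attempt < MAX_RETRIES:
            # Retry-After may be fractional (the tests use "0.01" and "2.5")
            time.sleep(float(last.headers.get("Retry-After", "1")))
    # Retries exhausted: surface an error dict instead of raising
    return {"error": "rate limit exceeded after retries"}
```

This is a sketch of the behaviour under test, not the tool's actual implementation; the real client presumably wraps `httpx.request` with equivalent logic.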
# --- Tool registration tests ---
class TestMattermostListTeamsTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_list_teams_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(return_value=[{"id": "t1", "name": "test-team"}]),
)
result = self._fn("mattermost_list_teams")()
assert result["success"] is True
assert len(result["teams"]) == 1
assert result["teams"][0]["name"] == "test-team"
def test_list_teams_no_credentials(self):
mcp = MagicMock()
fns = []
mcp.tool.return_value = lambda fn: fns.append(fn) or fn
register_tools(mcp, credentials=None)
with patch.dict("os.environ", {"MATTERMOST_ACCESS_TOKEN": ""}, clear=False):
result = next(f for f in fns if f.__name__ == "mattermost_list_teams")()
assert "error" in result
assert "not configured" in result["error"]
class TestMattermostListChannelsTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_list_channels_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value=[
{"id": "c1", "name": "town-square", "type": "O"},
]
),
)
result = self._fn("mattermost_list_channels")("team-123")
assert result["success"] is True
assert len(result["channels"]) == 1
assert result["channels"][0]["name"] == "town-square"
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_list_channels_error(self, mock_request):
mock_request.return_value = MagicMock(
status_code=404,
json=MagicMock(return_value={"message": "Unknown Team"}),
text="Unknown Team",
)
result = self._fn("mattermost_list_channels")("bad-team")
assert "error" in result
assert "404" in result["error"]
class TestMattermostSendMessageTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_send_message_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=201,
json=MagicMock(
return_value={
"id": "p123",
"channel_id": "c1",
"message": "Incident resolved",
}
),
)
result = self._fn("mattermost_send_message")("c1", "Incident resolved")
assert result["success"] is True
assert result["post"]["message"] == "Incident resolved"
def test_send_message_length_validation(self):
long_content = "x" * (MAX_MESSAGE_LENGTH + 1)
result = self._fn("mattermost_send_message")("c1", long_content)
assert "error" in result
assert str(MAX_MESSAGE_LENGTH) in result["error"]
assert result["max_length"] == MAX_MESSAGE_LENGTH
assert result["provided"] == MAX_MESSAGE_LENGTH + 1
def test_send_message_exactly_at_limit(self):
content = "x" * MAX_MESSAGE_LENGTH
with patch(
"aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request"
) as mock_request:
mock_request.return_value = MagicMock(
status_code=201,
json=MagicMock(return_value={"id": "p1", "channel_id": "c1", "message": content}),
)
result = self._fn("mattermost_send_message")("c1", content)
assert result["success"] is True
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_send_message_rate_limit_429_exhausted(self, mock_request):
mock_request.return_value = MagicMock(
status_code=429,
headers={"Retry-After": "5"},
text="{}",
)
result = self._fn("mattermost_send_message")("c1", "Hello")
assert "error" in result
assert "rate limit" in result["error"].lower()
assert mock_request.call_count == MAX_RETRIES + 1
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_send_message_rate_limit_then_success(self, mock_request):
mock_request.side_effect = [
MagicMock(
status_code=429,
headers={"Retry-After": "0.01"},
text="{}",
),
MagicMock(
status_code=201,
json=MagicMock(return_value={"id": "p1", "channel_id": "c1", "message": "Hi"}),
),
]
result = self._fn("mattermost_send_message")("c1", "Hi")
assert result["success"] is True
assert result["post"]["message"] == "Hi"
assert mock_request.call_count == 2
class TestMattermostGetPostsTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_get_posts_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value={
"order": ["p1"],
"posts": {"p1": {"id": "p1", "message": "First message"}},
}
),
)
result = self._fn("mattermost_get_posts")("c1", per_page=10)
assert result["success"] is True
assert "posts" in result
class TestMattermostGetChannelTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_get_channel_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200,
json=MagicMock(
return_value={
"id": "c1",
"name": "town-square",
"display_name": "Town Square",
"type": "O",
}
),
)
result = self._fn("mattermost_get_channel")("c1")
assert result["success"] is True
assert result["channel"]["name"] == "town-square"
class TestMattermostDeletePostTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_delete_post_success(self, mock_request):
mock_request.return_value = MagicMock(
status_code=200, json=MagicMock(return_value={"status": "ok"})
)
result = self._fn("mattermost_delete_post")("p123")
assert result["success"] is True
assert result["deleted_post_id"] == "p123"
class TestMattermostCreateReactionTool:
def setup_method(self):
self.mcp = MagicMock()
self.fns = []
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
cred = MagicMock()
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": "https://mattermost.example.com",
}.get(key)
register_tools(self.mcp, credentials=cred)
def _fn(self, name):
return next(f for f in self.fns if f.__name__ == name)
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
def test_create_reaction_success(self, mock_request):
mock_request.side_effect = [
MagicMock(
status_code=200,
json=MagicMock(return_value={"id": "user123"}),
),
MagicMock(
status_code=200,
json=MagicMock(
return_value={
"user_id": "user123",
"post_id": "p123",
"emoji_name": "thumbsup",
}
),
),
]
result = self._fn("mattermost_create_reaction")("p123", "thumbsup")
assert result["success"] is True
class TestMattermostNoUrl:
"""Test that missing URL returns a helpful error."""
def test_missing_url_returns_error(self):
mcp = MagicMock()
fns = []
mcp.tool.return_value = lambda fn: fns.append(fn) or fn
cred = MagicMock()
# Token is set but URL is not
cred.get.side_effect = lambda key: {
"mattermost": "test-token",
"mattermost_url": None,
}.get(key)
register_tools(mcp, credentials=cred)
with patch.dict("os.environ", {"MATTERMOST_URL": ""}, clear=False):
fn = next(f for f in fns if f.__name__ == "mattermost_list_teams")
result = fn()
assert "error" in result
assert "URL" in result["error"]
# --- Credential spec tests ---
class TestCredentialSpec:
def test_mattermost_credential_spec_exists(self):
from aden_tools.credentials import CREDENTIAL_SPECS
assert "mattermost" in CREDENTIAL_SPECS
def test_mattermost_spec_env_var(self):
from aden_tools.credentials import CREDENTIAL_SPECS
spec = CREDENTIAL_SPECS["mattermost"]
assert spec.env_var == "MATTERMOST_ACCESS_TOKEN"
def test_mattermost_spec_tools(self):
from aden_tools.credentials import CREDENTIAL_SPECS
spec = CREDENTIAL_SPECS["mattermost"]
assert "mattermost_list_teams" in spec.tools
assert "mattermost_list_channels" in spec.tools
assert "mattermost_get_channel" in spec.tools
assert "mattermost_send_message" in spec.tools
assert "mattermost_get_posts" in spec.tools
assert "mattermost_create_reaction" in spec.tools
assert "mattermost_delete_post" in spec.tools
assert len(spec.tools) == 7
def test_mattermost_url_credential_spec_exists(self):
from aden_tools.credentials import CREDENTIAL_SPECS
assert "mattermost_url" in CREDENTIAL_SPECS
def test_mattermost_url_spec_env_var(self):
from aden_tools.credentials import CREDENTIAL_SPECS
spec = CREDENTIAL_SPECS["mattermost_url"]
assert spec.env_var == "MATTERMOST_URL"
@@ -22,13 +22,30 @@ def _mock_resp(data, status_code=200):
return resp
def _mock_credentials() -> MagicMock:
creds = MagicMock()
creds.get.side_effect = lambda key: {
"sap_base_url": "https://cred-store.s4hana.ondemand.com",
"sap_username": "CRED_USER",
"sap_password": "cred-password",
}.get(key)
return creds
@pytest.fixture
def tool_fns(mcp: FastMCP) -> dict:
register_tools(mcp, credentials=None)
tools = mcp._tool_manager._tools
return {name: tools[name].fn for name in tools}
@pytest.fixture
def tool_fns_with_creds(mcp: FastMCP) -> dict:
register_tools(mcp, credentials=_mock_credentials())
tools = mcp._tool_manager._tools
return {name: tools[name].fn for name in tools}
class TestSAPListPurchaseOrders:
def test_missing_credentials(self, tool_fns):
with patch.dict("os.environ", {}, clear=True):
@@ -176,3 +193,66 @@ class TestSAPListSalesOrders:
assert result["count"] == 1
assert result["sales_orders"][0]["sales_order"] == "1"
assert result["sales_orders"][0]["net_amount"] == "25000.00"
class TestCredentialStoreAdapter:
"""Verify credentials are resolved via CredentialStoreAdapter."""
def test_credential_store_used(self, tool_fns_with_creds):
data = {
"d": {
"__count": "1",
"results": [
{
"PurchaseOrder": "4500000001",
"PurchaseOrderType": "NB",
"CompanyCode": "1010",
"Supplier": "17300001",
"CreationDate": "/Date(1672531200000)/",
"PurchaseOrderNetAmount": "15000.00",
"DocumentCurrency": "USD",
}
],
}
}
with patch(
"aden_tools.tools.sap_tool.sap_tool.httpx.get",
return_value=_mock_resp(data),
) as mock_get:
result = tool_fns_with_creds["sap_list_purchase_orders"]()
assert result["count"] == 1
call_url = mock_get.call_args.args[0]
assert "cred-store.s4hana.ondemand.com" in call_url
def test_credential_store_missing_values(self):
creds = MagicMock()
creds.get.return_value = None
mcp = FastMCP("test")
register_tools(mcp, credentials=creds)
tools = mcp._tool_manager._tools
fn = tools["sap_list_purchase_orders"].fn
result = fn()
assert "error" in result
def test_env_fallback_when_no_adapter(self, tool_fns):
data = {
"d": {
"__count": "0",
"results": [],
}
}
with (
patch.dict("os.environ", ENV),
patch(
"aden_tools.tools.sap_tool.sap_tool.httpx.get",
return_value=_mock_resp(data),
) as mock_get,
):
result = tool_fns["sap_list_purchase_orders"]()
assert result["count"] == 0
call_url = mock_get.call_args.args[0]
assert "my-tenant-api.s4hana.ondemand.com" in call_url
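These three tests describe the resolution order the SAP tool is expected to follow: prefer the `CredentialStoreAdapter`, fall back to environment variables when no adapter is supplied, and return an error payload when neither yields a value. A sketch of that order, under assumptions: the helper names and the `SAP_*` environment variable names are illustrative, and the adapter is only assumed to expose `.get(key)`:

```python
import os


def _resolve(credentials, key: str, env_var: str):
    """Prefer the credential store adapter; fall back to the environment."""
    if credentials is not None:
        value = credentials.get(key)
        if value:
            return value
    return os.getenv(env_var) or None


def resolve_sap_credentials(credentials=None):
    """Return base_url/username/password, or None if any value is missing.

    A None result is what lets the tool return an {"error": ...} payload,
    as test_credential_store_missing_values expects.
    """
    resolved = {
        "base_url": _resolve(credentials, "sap_base_url", "SAP_BASE_URL"),
        "username": _resolve(credentials, "sap_username", "SAP_USERNAME"),
        "password": _resolve(credentials, "sap_password", "SAP_PASSWORD"),
    }
    if not all(resolved.values()):
        return None
    return resolved
```

The adapter-first ordering is what `test_credential_store_used` verifies: when both the store and the environment are populated, the request must go to the store's base URL.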
@@ -1,11 +1,16 @@
"""Tests for web_scrape tool (FastMCP)."""
import socket
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp import FastMCP
from aden_tools.tools.web_scrape_tool import register_tools
from aden_tools.tools.web_scrape_tool.web_scrape_tool import (
_check_url_target,
_is_internal_address,
)
@pytest.fixture
@@ -430,3 +435,100 @@ class TestWebScrapeToolRobotsTxt:
result = await web_scrape_fn(url="https://example.com", respect_robots_txt=False)
assert "error" not in result
mock_rp_cls.assert_not_called()
_MOD = "aden_tools.tools.web_scrape_tool.web_scrape_tool"
class TestIsInternalAddress:
"""Tests for _is_internal_address."""
def test_loopback_ipv4(self):
assert _is_internal_address("127.0.0.1") is True
def test_private_10_range(self):
assert _is_internal_address("10.0.0.1") is True
def test_private_192_168(self):
assert _is_internal_address("192.168.1.1") is True
def test_link_local_aws_metadata(self):
assert _is_internal_address("169.254.169.254") is True
def test_public_ipv4(self):
assert _is_internal_address("8.8.8.8") is False
def test_public_ipv6(self):
assert _is_internal_address("2607:f8b0:4004:800::200e") is False
def test_invalid_string_blocked(self):
assert _is_internal_address("not-an-ip") is True
def _fake_addrinfo(ip: str, port: int = 443) -> list[tuple]:
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))]
class TestCheckUrlTarget:
"""Tests for _check_url_target."""
@patch(f"{_MOD}.socket.getaddrinfo")
def test_public_hostname_allowed(self, mock_dns):
mock_dns.return_value = _fake_addrinfo("93.184.216.34")
assert _check_url_target("https://example.com/page") is None
@patch(f"{_MOD}.socket.getaddrinfo")
def test_private_hostname_blocked(self, mock_dns):
mock_dns.return_value = _fake_addrinfo("10.0.0.1")
result = _check_url_target("https://evil.com/steal")
assert result is not None
assert "internal" in result.lower()
def test_raw_private_ip_blocked(self):
result = _check_url_target("http://127.0.0.1/admin")
assert result is not None
@patch(
f"{_MOD}.socket.getaddrinfo",
side_effect=socket.gaierror("NXDOMAIN"),
)
def test_dns_failure_returns_error(self, _mock_dns):
result = _check_url_target("https://nonexistent.invalid/")
assert result is not None
assert "DNS" in result
class TestWebScrapeSSRF:
"""SSRF protection through the web_scrape tool."""
@pytest.mark.asyncio
async def test_blocks_private_ip(self, web_scrape_fn):
result = await web_scrape_fn(url="http://192.168.1.1/admin")
assert "error" in result
assert result.get("blocked_by_ssrf_protection") is True
@pytest.mark.asyncio
async def test_blocks_localhost(self, web_scrape_fn):
result = await web_scrape_fn(url="http://127.0.0.1/secret")
assert "error" in result
assert result.get("blocked_by_ssrf_protection") is True
@pytest.mark.asyncio
async def test_blocks_metadata_endpoint(self, web_scrape_fn):
result = await web_scrape_fn(url="http://169.254.169.254/latest/meta-data/")
assert "error" in result
assert result.get("blocked_by_ssrf_protection") is True
@pytest.mark.asyncio
@patch(_STEALTH_PATH)
@patch(_PW_PATH)
@patch(f"{_MOD}._check_url_target", return_value=None)
async def test_allows_public_url(self, _mock_check, mock_pw, mock_stealth, web_scrape_fn):
html = "<html><body><p>Hello world</p></body></html>"
mock_cm, _, _ = _make_playwright_mocks(html)
mock_pw.return_value = mock_cm
mock_stealth.return_value.apply_stealth_async = AsyncMock()
result = await web_scrape_fn(url="https://example.com/")
assert "error" not in result
assert "Hello world" in result["content"]