Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f748187391 | |||
| eafbeb78b4 | |||
| 5cb5083f8d | |||
| bf86daee92 | |||
| 43bbd0f31f | |||
| 2cf962b538 | |||
| 4298196700 | |||
| bc1f712e42 | |||
| cccbcc8ec3 | |||
| 0722f83f16 | |||
| 72091d2783 | |||
| 3bb69a5784 | |||
| 63fb089062 | |||
| d5ba985e29 | |||
| 6ee510d2f6 | |||
| 45b350e7c8 | |||
| 7e690de12f | |||
| ae85d2bf59 | |||
| e9fd0158b9 | |||
| 9a68a5d7ee | |||
| 33edf4a207 | |||
| f9fdaf5adc | |||
| eabb17934c | |||
| eba7524955 | |||
| c56440340a | |||
| c889ffd85d | |||
| 905a4f3516 | |||
| 941605720f | |||
| 72e5c5c1c6 | |||
| 0f42c8c8c1 | |||
| c3c3075610 | |||
| 86ef6fd8c5 | |||
| 95bdf4fe32 | |||
| 890d303d26 | |||
| 7fe60991e1 | |||
| a72938a163 | |||
| 326a3dd1b7 | |||
| 183c6e2620 | |||
| 1b40bff7da | |||
| 38b79edaee | |||
| eb4f180192 | |||
| bf0b9a1edb | |||
| 9667dd25cb | |||
| 33e4e8d440 | |||
| c5ac29c81d | |||
| 13c072d731 | |||
| 5e31975cc3 | |||
| 82af76e72a | |||
| a483f8d06a | |||
| e188c26e9f | |||
| 22d9fba1fd | |||
| c7d0afc775 | |||
| 45aafbc52b | |||
| 567340c05d | |||
| 8d8656193d | |||
| ef317371ce | |||
| d5596ccb0a | |||
| 5f1530ec5b | |||
| 8af32b421c | |||
| 95cc8a4513 | |||
| d648f3d315 | |||
| b43044cf4d | |||
| 4724320946 | |||
| 89ab2e0a74 | |||
| ee4682c565 | |||
| 9b59255770 | |||
| 49fd443da8 | |||
| b599a760e8 | |||
| b4a37cdb03 | |||
| 4885db318e | |||
| fa7ce53fb3 | |||
| 75a2ef2c4a | |||
| a0b9d6afaf | |||
| 74c0a85e3f | |||
| 96609386a3 | |||
| 0cef0e6990 | |||
| 2f15a16159 | |||
| d433cda209 |
+4
-2
@@ -13,6 +13,10 @@ out/
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
.venv
|
||||
/venv
|
||||
tools/src/uv.lock
|
||||
|
||||
|
||||
# User configuration (copied from .example)
|
||||
config.yaml
|
||||
@@ -69,8 +73,6 @@ exports/*
|
||||
|
||||
.claude/settings.local.json
|
||||
|
||||
.venv
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Agent_Harness-Runtime_Layer-ff6600?style=flat-square" alt="Agent Harness" />
|
||||
<img src="https://img.shields.io/badge/AI_Agents-Self--Improving-brightgreen?style=flat-square" alt="AI Agents" />
|
||||
<img src="https://img.shields.io/badge/Multi--Agent-Systems-blue?style=flat-square" alt="Multi-Agent" />
|
||||
<img src="https://img.shields.io/badge/Headless-Development-purple?style=flat-square" alt="Headless" />
|
||||
@@ -35,39 +36,42 @@
|
||||
<img src="https://img.shields.io/badge/Google_Gemini-supported-4285F4?style=flat-square&logo=google" alt="Gemini" />
|
||||
</p>
|
||||
|
||||
<p align="center"><em>The agent harness for production workloads — state management, failure recovery, observability, and human oversight so your agents actually run.</em></p>
|
||||
|
||||
## Overview
|
||||
|
||||
Generate a swarm of worker agents with a coding agent(queen) that control them. Define your goal through conversation with hive queen, and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, browser use, credential management, and real-time monitoring give you control without sacrificing adaptability.
|
||||
Hive is a runtime harness for AI agents in production. You describe your goal in natural language; a coding agent (the queen) generates the agent graph and connection code to achieve it. During execution, the harness manages state isolation, checkpoint-based crash recovery, cost enforcement, and real-time observability. When agents fail, the framework captures failure data, evolves the graph through the coding agent, and redeploys automatically. Built-in human-in-the-loop nodes, browser control, credential management, and parallel execution give you production reliability without sacrificing adaptability.
|
||||
|
||||
Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.
|
||||
|
||||
Visit [HoneyComb](http://honeycomb.open-hive.com/) to see what jobs are being automated by AI. It’s a stock market for jobs, driven by our community’s AI agent progress. You can long and short jobs (with no real money but compute token)based on how much you think a job is going to be replaced by AI.
|
||||
|
||||
https://github.com/user-attachments/assets/bf10edc3-06ba-48b6-98ba-d069b15fb69d
|
||||
|
||||
|
||||
## Who Is Hive For?
|
||||
|
||||
Hive is designed for developers and teams who want to build many **autonomous AI agents** fast without manually wiring complex workflows.
|
||||
Hive is the harness layer for teams moving AI agents from prototype to production. Models are getting better on their own — the bottleneck is the infrastructure around them: state management, failure recovery, cost control, and observability.
|
||||
|
||||
Hive is a good fit if you:
|
||||
|
||||
- Want AI agents that **execute real business processes**, not demos
|
||||
- Need **fast or high volume agent execution** over open workflow
|
||||
- Need a **runtime that handles state, recovery, and parallel execution** at scale
|
||||
- Need **self-healing and adaptive agents** that improve over time
|
||||
- Require **human-in-the-loop control**, observability, and cost limits
|
||||
- Plan to run agents in **production environments**
|
||||
- Plan to run agents in **production** where uptime, cost, and auditability matter
|
||||
|
||||
Hive may not be the best fit if you’re only experimenting with simple agent chains or one-off scripts.
|
||||
|
||||
## When Should You Use Hive?
|
||||
|
||||
Use Hive when you need:
|
||||
Use Hive when the bottleneck is no longer the model but the harness around it:
|
||||
|
||||
- Long-running, autonomous agents
|
||||
- Strong guardrails, process, and controls
|
||||
- Continuous improvement based on failures
|
||||
- Multi-agent coordination
|
||||
- A framework that evolves with your goals
|
||||
- Long-running agents that need **state persistence and crash recovery**
|
||||
- Production workloads requiring **cost enforcement, observability, and audit trails**
|
||||
- Agents that **self-heal** through failure capture and graph evolution
|
||||
- Multi-agent coordination with **session isolation and shared memory**
|
||||
- A framework that **scales with model improvements** rather than fighting them
|
||||
|
||||
## Quick Links
|
||||
|
||||
@@ -100,9 +104,11 @@ Use Hive when you need:
|
||||
git clone https://github.com/aden-hive/hive.git
|
||||
cd hive
|
||||
|
||||
|
||||
# Run quickstart setup
|
||||
# Run quickstart setup (macOS/Linux)
|
||||
./quickstart.sh
|
||||
|
||||
# Windows (PowerShell)
|
||||
.\quickstart.ps1
|
||||
```
|
||||
|
||||
This sets up:
|
||||
@@ -152,9 +158,9 @@ Hive is built to be model-agnostic and system-agnostic.
|
||||
- **LLM flexibility** - Hive Framework supports Anthropic, OpenAI, OpenRouter, Hive LLM, and other hosted or local models through LiteLLM-compatible providers.
|
||||
- **Business system connectivity** - Hive Framework is designed to connect to all kinds of business systems as tools, such as CRM, support, messaging, data, file, and internal APIs via MCP.
|
||||
|
||||
## Why Aden
|
||||
## Why Hive
|
||||
|
||||
Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
As models improve, the upper bound of what agents can do rises — but their reliability and production value are determined by the harness. Hive focuses on generating agents that run real business processes rather than generic agents. Instead of requiring you to manually design workflows, define agent interactions, and handle failures reactively, Hive flips the paradigm: **you describe outcomes, and the system builds itself**—delivering an outcome-driven, adaptive experience with an easy-to-use set of tools and integrations.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
@@ -190,8 +196,9 @@ flowchart LR
|
||||
|
||||
### The Hive Advantage
|
||||
|
||||
| Traditional Frameworks | Hive |
|
||||
| Typical Agent Frameworks | Hive |
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| Focus on model orchestration | **Production harness**: state, recovery, observability |
|
||||
| Hardcode agent workflows | Describe goals in natural language |
|
||||
| Manual graph definition | Auto-generated agent graphs |
|
||||
| Reactive error handling | Outcome-evaluation and adaptiveness |
|
||||
@@ -385,7 +392,7 @@ Yes! Hive supports local models through LiteLLM. Simply use the model name forma
|
||||
|
||||
**Q: What makes Hive different from other agent frameworks?**
|
||||
|
||||
Hive generates your entire agent system from natural language goals using a coding agent—you don't hardcode workflows or manually define graphs. When agents fail, the framework automatically captures failure data, [evolves the agent graph](docs/key_concepts/evolution.md), and redeploys. This self-improving loop is unique to Aden.
|
||||
Hive is an agent harness, not just an orchestration framework. It provides the production runtime layer — session isolation, checkpoint-based crash recovery, cost enforcement, real-time observability, and human-in-the-loop controls — that makes agents reliable enough to run real workloads. On top of that, Hive generates your entire agent system from natural language goals and automatically [evolves the graph](docs/key_concepts/evolution.md) when agents fail. The combination of a robust harness with self-improving generation is what sets Hive apart.
|
||||
|
||||
**Q: Is Hive open-source?**
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ This guide explains how to integrate Model Context Protocol (MCP) servers with t
|
||||
|
||||
The framework provides built-in support for MCP servers, allowing you to:
|
||||
|
||||
- **Register MCP servers** via STDIO or HTTP transport
|
||||
- **Register MCP servers** via STDIO, HTTP, Unix socket, or SSE transport
|
||||
- **Auto-discover tools** from registered servers
|
||||
- **Use MCP tools** seamlessly in your agents
|
||||
- **Manage multiple MCP servers** simultaneously
|
||||
@@ -104,6 +104,48 @@ runner.register_mcp_server(
|
||||
- `url`: Base URL of the MCP server
|
||||
- `headers`: HTTP headers to include (optional)
|
||||
|
||||
### Unix Socket Transport
|
||||
|
||||
Best for same-host inter-process communication with lower overhead than TCP:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="local-ipc-tools",
|
||||
transport="unix",
|
||||
url="http://localhost",
|
||||
socket_path="/tmp/mcp_server.sock",
|
||||
headers={
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `url`: Base URL for HTTP requests over the socket (required, e.g., `"http://localhost"`)
|
||||
- `socket_path`: Absolute path to the Unix socket file (required, e.g., `"/tmp/mcp_server.sock"`)
|
||||
- `headers`: HTTP headers to include (optional)
|
||||
|
||||
### SSE Transport
|
||||
|
||||
Best for real-time, event-driven connections using the MCP SDK's SSE client:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="streaming-tools",
|
||||
transport="sse",
|
||||
url="http://localhost:8000/sse",
|
||||
headers={
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `url`: SSE endpoint URL (required, e.g., `"http://localhost:8000/sse"`)
|
||||
- `headers`: HTTP headers for the SSE connection (optional)
|
||||
|
||||
## Using MCP Tools in Agents
|
||||
|
||||
Once registered, MCP tools are available just like any other tool:
|
||||
@@ -258,7 +300,32 @@ runner.register_mcp_server(
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Handle Cleanup
|
||||
### 3. Use Unix Socket for Same-Host IPC
|
||||
|
||||
When both the agent and MCP server run on the same machine, Unix sockets avoid TCP overhead:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="fast-local-tools",
|
||||
transport="unix",
|
||||
url="http://localhost",
|
||||
socket_path="/tmp/mcp_server.sock"
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Use SSE for Streaming and Real-Time Tools
|
||||
|
||||
SSE transport maintains a persistent connection, ideal for event-driven servers:
|
||||
|
||||
```python
|
||||
runner.register_mcp_server(
|
||||
name="realtime-tools",
|
||||
transport="sse",
|
||||
url="http://realtime-server:8000/sse"
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Handle Cleanup
|
||||
|
||||
Always clean up MCP connections when done:
|
||||
|
||||
@@ -280,7 +347,7 @@ async with AgentRunner.load("exports/my-agent") as runner:
|
||||
# Automatic cleanup
|
||||
```
|
||||
|
||||
### 4. Tool Name Conflicts
|
||||
### 6. Tool Name Conflicts
|
||||
|
||||
If multiple MCP servers provide tools with the same name, the last registered server wins. To avoid conflicts:
|
||||
|
||||
@@ -315,6 +382,24 @@ If HTTP transport fails:
|
||||
2. Check firewall settings
|
||||
3. Verify the URL and port are correct
|
||||
|
||||
### Unix Socket Not Connecting
|
||||
|
||||
If Unix socket transport fails:
|
||||
|
||||
1. Verify the socket file exists: `ls -la /tmp/mcp_server.sock`
|
||||
2. Check file permissions on the socket
|
||||
3. Ensure no other process has locked the socket
|
||||
4. Verify the `url` field is set (e.g., `"http://localhost"`)
|
||||
|
||||
### SSE Connection Issues
|
||||
|
||||
If SSE transport fails:
|
||||
|
||||
1. Verify the server supports SSE at the given URL
|
||||
2. Check that the `mcp` Python package is installed (`pip install mcp`)
|
||||
3. Ensure the SSE endpoint is accessible: `curl http://localhost:8000/sse`
|
||||
4. Check for firewall or proxy issues blocking long-lived connections
|
||||
|
||||
## Example: Full Agent with MCP Tools
|
||||
|
||||
Here's a complete example of an agent that uses MCP tools:
|
||||
|
||||
@@ -584,11 +584,19 @@ class CredentialTesterAgent:
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
try:
|
||||
agent_dir = Path(__file__).parent
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
registry_configs = registry.load_agent_selection(Path(__file__).parent)
|
||||
if (agent_dir / "mcp_registry.json").is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(agent_dir)
|
||||
if registry_configs:
|
||||
self._tool_registry.load_registry_servers(registry_configs)
|
||||
self._tool_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("MCP registry config failed to load", exc_info=True)
|
||||
|
||||
|
||||
@@ -31,6 +31,11 @@ def _queen_dir() -> Path:
|
||||
return Path.home() / ".hive" / "queen"
|
||||
|
||||
|
||||
def format_memory_date(d: date) -> str:
|
||||
"""Return a cross-platform long date label without a zero-padded day."""
|
||||
return f"{d.strftime('%B')} {d.day}, {d.year}"
|
||||
|
||||
|
||||
def semantic_memory_path() -> Path:
|
||||
return _queen_dir() / "MEMORY.md"
|
||||
|
||||
@@ -91,9 +96,9 @@ def format_for_injection() -> str:
|
||||
content = content[:_EPISODIC_CHAR_BUDGET] + "\n\n…(truncated)"
|
||||
today = date.today()
|
||||
if d == today:
|
||||
label = f"## Today — {d.strftime('%B %-d, %Y')}"
|
||||
label = f"## Today — {format_memory_date(d)}"
|
||||
else:
|
||||
label = f"## {d.strftime('%B %-d, %Y')}"
|
||||
label = f"## {format_memory_date(d)}"
|
||||
parts.append(f"{label}\n\n{content}")
|
||||
|
||||
if not parts:
|
||||
@@ -127,7 +132,7 @@ def append_episodic_entry(content: str) -> None:
|
||||
ep_path = episodic_memory_path()
|
||||
ep_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
today_str = format_memory_date(today)
|
||||
timestamp = datetime.now().strftime("%H:%M")
|
||||
if not ep_path.exists():
|
||||
header = f"# {today_str}\n\n"
|
||||
@@ -331,7 +336,7 @@ async def consolidate_queen_memory(
|
||||
existing_semantic = read_semantic_memory()
|
||||
today_journal = read_episodic_memory()
|
||||
today = date.today()
|
||||
today_str = f"{today.strftime('%B')} {today.day}, {today.year}"
|
||||
today_str = format_memory_date(today)
|
||||
adapt_path = session_dir / "data" / "adapt.md"
|
||||
|
||||
user_msg = (
|
||||
|
||||
@@ -99,6 +99,11 @@ def main():
|
||||
|
||||
register_debugger_commands(subparsers)
|
||||
|
||||
# Register MCP registry commands (mcp install, mcp add, ...)
|
||||
from framework.runner.mcp_registry_cli import register_mcp_commands
|
||||
|
||||
register_mcp_commands(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if hasattr(args, "func"):
|
||||
|
||||
@@ -186,6 +186,8 @@ def get_worker_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
if worker_llm.get("provider") == "ollama":
|
||||
return {"num_ctx": worker_llm.get("num_ctx", 16384)}
|
||||
return {}
|
||||
|
||||
|
||||
@@ -432,6 +434,11 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
|
||||
"store": False,
|
||||
"allowed_openai_params": ["store"],
|
||||
}
|
||||
if llm.get("provider") == "ollama":
|
||||
# Pass num_ctx to Ollama so it doesn't silently truncate the ~9.5k Queen prompt.
|
||||
# Ollama's default num_ctx is only 2048. We set it to 16384 here so LiteLLM
|
||||
# passes it through as a provider-specific option.
|
||||
return {"num_ctx": llm.get("num_ctx", 16384)}
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""EventLoopNode subpackage — modular components of the event loop orchestrator.
|
||||
|
||||
All public symbols are re-exported by the parent ``event_loop_node.py`` for
|
||||
backward compatibility. Internal consumers may import directly from these
|
||||
submodules for clarity.
|
||||
"""
|
||||
@@ -0,0 +1,652 @@
|
||||
"""Conversation compaction pipeline.
|
||||
|
||||
Implements the multi-level compaction strategy:
|
||||
1. Prune old tool results
|
||||
2. Structure-preserving compaction (spillover)
|
||||
3. LLM summary compaction (with recursive splitting)
|
||||
4. Emergency deterministic summary (no LLM)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.event_publishing import publish_context_usage
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limits for LLM compaction
|
||||
LLM_COMPACT_CHAR_LIMIT: int = 240_000
|
||||
LLM_COMPACT_MAX_DEPTH: int = 10
|
||||
|
||||
|
||||
async def compact(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator | None,
|
||||
*,
|
||||
config: LoopConfig,
|
||||
event_bus: EventBus | None,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
) -> None:
|
||||
"""Run the full compaction pipeline if conversation needs compaction.
|
||||
|
||||
Pipeline stages (in order, short-circuits when budget is restored):
|
||||
1. Prune old tool results
|
||||
2. Structure-preserving compaction (free, no LLM)
|
||||
3. LLM summary compaction (recursive split if too large)
|
||||
4. Emergency deterministic summary (fallback)
|
||||
"""
|
||||
ratio_before = conversation.usage_ratio()
|
||||
phase_grad = getattr(ctx, "continuous_mode", False)
|
||||
pre_inventory: list[dict[str, Any]] | None = None
|
||||
|
||||
if ratio_before >= 1.0:
|
||||
pre_inventory = build_message_inventory(conversation)
|
||||
|
||||
# --- Step 1: Prune old tool results (free, fast) ---
|
||||
protect = max(2000, config.max_context_tokens // 12)
|
||||
pruned = await conversation.prune_old_tool_results(
|
||||
protect_tokens=protect,
|
||||
min_prune_tokens=max(1000, protect // 3),
|
||||
)
|
||||
if pruned > 0:
|
||||
logger.info(
|
||||
"Pruned %d old tool results: %.0f%% -> %.0f%%",
|
||||
pruned,
|
||||
ratio_before * 100,
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 2: Standard structure-preserving compaction (free, no LLM) ---
|
||||
spill_dir = config.spillover_dir
|
||||
if spill_dir:
|
||||
await conversation.compact_preserving_structure(
|
||||
spillover_dir=spill_dir,
|
||||
keep_recent=4,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 3: LLM summary compaction ---
|
||||
if ctx.llm is not None:
|
||||
logger.info(
|
||||
"LLM summary compaction triggered (%.0f%% usage)",
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
try:
|
||||
summary = await llm_compact(
|
||||
ctx,
|
||||
list(conversation.messages),
|
||||
accumulator,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=config.max_context_tokens,
|
||||
)
|
||||
await conversation.compact(
|
||||
summary,
|
||||
keep_recent=2,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("LLM compaction failed: %s", e)
|
||||
|
||||
if not conversation.needs_compaction():
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
return
|
||||
|
||||
# --- Step 4: Emergency deterministic summary (LLM failed/unavailable) ---
|
||||
logger.warning(
|
||||
"Emergency compaction (%.0f%% usage)",
|
||||
conversation.usage_ratio() * 100,
|
||||
)
|
||||
summary = build_emergency_summary(ctx, accumulator, conversation, config)
|
||||
await conversation.compact(
|
||||
summary,
|
||||
keep_recent=1,
|
||||
phase_graduated=phase_grad,
|
||||
)
|
||||
await log_compaction(
|
||||
ctx,
|
||||
conversation,
|
||||
ratio_before,
|
||||
event_bus,
|
||||
pre_inventory=pre_inventory,
|
||||
)
|
||||
|
||||
|
||||
# --- LLM compaction with binary-search splitting ----------------------
|
||||
|
||||
|
||||
async def llm_compact(
|
||||
ctx: NodeContext,
|
||||
messages: list,
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
_depth: int = 0,
|
||||
*,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Summarise *messages* with LLM, splitting recursively if too large.
|
||||
|
||||
If the formatted text exceeds ``LLM_COMPACT_CHAR_LIMIT`` or the LLM
|
||||
rejects the call with a context-length error, the messages are split
|
||||
in half and each half is summarised independently. Tool history is
|
||||
appended once at the top-level call (``_depth == 0``).
|
||||
"""
|
||||
from framework.graph.conversation import extract_tool_call_history
|
||||
from framework.graph.event_loop.tool_result_handler import is_context_too_large_error
|
||||
|
||||
if _depth > max_depth:
|
||||
raise RuntimeError(f"LLM compaction recursion limit ({max_depth})")
|
||||
|
||||
formatted = format_messages_for_summary(messages)
|
||||
|
||||
# Proactive split: avoid wasting an API call on oversized input
|
||||
if len(formatted) > char_limit and len(messages) > 1:
|
||||
summary = await _llm_compact_split(
|
||||
ctx,
|
||||
messages,
|
||||
accumulator,
|
||||
_depth,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
else:
|
||||
prompt = build_llm_compaction_prompt(
|
||||
ctx,
|
||||
accumulator,
|
||||
formatted,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
summary_budget = max(1024, max_context_tokens // 2)
|
||||
try:
|
||||
response = await ctx.llm.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system=(
|
||||
"You are a conversation compactor for an AI agent. "
|
||||
"Write a detailed summary that allows the agent to "
|
||||
"continue its work. Preserve user-stated rules, "
|
||||
"constraints, and account/identity preferences verbatim."
|
||||
),
|
||||
max_tokens=summary_budget,
|
||||
)
|
||||
summary = response.content
|
||||
except Exception as e:
|
||||
if is_context_too_large_error(e) and len(messages) > 1:
|
||||
logger.info(
|
||||
"LLM context too large (depth=%d, msgs=%d) — splitting",
|
||||
_depth,
|
||||
len(messages),
|
||||
)
|
||||
summary = await _llm_compact_split(
|
||||
ctx,
|
||||
messages,
|
||||
accumulator,
|
||||
_depth,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
# Append tool history at top level only
|
||||
if _depth == 0:
|
||||
tool_history = extract_tool_call_history(messages)
|
||||
if tool_history and "TOOLS ALREADY CALLED" not in summary:
|
||||
summary += "\n\n" + tool_history
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _llm_compact_split(
|
||||
ctx: NodeContext,
|
||||
messages: list,
|
||||
accumulator: OutputAccumulator | None,
|
||||
_depth: int,
|
||||
*,
|
||||
char_limit: int = LLM_COMPACT_CHAR_LIMIT,
|
||||
max_depth: int = LLM_COMPACT_MAX_DEPTH,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Split messages in half and summarise each half independently."""
|
||||
mid = max(1, len(messages) // 2)
|
||||
s1 = await llm_compact(
|
||||
ctx,
|
||||
messages[:mid],
|
||||
None,
|
||||
_depth + 1,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
s2 = await llm_compact(
|
||||
ctx,
|
||||
messages[mid:],
|
||||
accumulator,
|
||||
_depth + 1,
|
||||
char_limit=char_limit,
|
||||
max_depth=max_depth,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
return s1 + "\n\n" + s2
|
||||
|
||||
|
||||
# --- Compaction helpers ------------------------------------------------
|
||||
|
||||
|
||||
def format_messages_for_summary(messages: list) -> str:
|
||||
"""Format messages as text for LLM summarisation."""
|
||||
lines: list[str] = []
|
||||
for m in messages:
|
||||
if m.role == "tool":
|
||||
content = m.content[:500]
|
||||
if len(m.content) > 500:
|
||||
content += "..."
|
||||
lines.append(f"[tool result]: {content}")
|
||||
elif m.role == "assistant" and m.tool_calls:
|
||||
names = [tc.get("function", {}).get("name", "?") for tc in m.tool_calls]
|
||||
text = m.content[:200] if m.content else ""
|
||||
lines.append(f"[assistant (calls: {', '.join(names)})]: {text}")
|
||||
else:
|
||||
lines.append(f"[{m.role}]: {m.content}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def build_llm_compaction_prompt(
|
||||
ctx: NodeContext,
|
||||
accumulator: OutputAccumulator | None,
|
||||
formatted_messages: str,
|
||||
*,
|
||||
max_context_tokens: int = 128_000,
|
||||
) -> str:
|
||||
"""Build prompt for LLM compaction targeting 50% of token budget."""
|
||||
spec = ctx.node_spec
|
||||
ctx_lines = [f"NODE: {spec.name} (id={spec.id})"]
|
||||
if spec.description:
|
||||
ctx_lines.append(f"PURPOSE: {spec.description}")
|
||||
if spec.success_criteria:
|
||||
ctx_lines.append(f"SUCCESS CRITERIA: {spec.success_criteria}")
|
||||
|
||||
if accumulator:
|
||||
acc = accumulator.to_dict()
|
||||
done = {k: v for k, v in acc.items() if v is not None}
|
||||
todo = [k for k, v in acc.items() if v is None]
|
||||
if done:
|
||||
ctx_lines.append(
|
||||
"OUTPUTS ALREADY SET:\n"
|
||||
+ "\n".join(f" {k}: {str(v)[:150]}" for k, v in done.items())
|
||||
)
|
||||
if todo:
|
||||
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(todo)}")
|
||||
elif spec.output_keys:
|
||||
ctx_lines.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
|
||||
|
||||
target_tokens = max_context_tokens // 2
|
||||
target_chars = target_tokens * 4
|
||||
node_ctx = "\n".join(ctx_lines)
|
||||
|
||||
return (
|
||||
"You are compacting an AI agent's conversation history. "
|
||||
"The agent is still working and needs to continue.\n\n"
|
||||
f"AGENT CONTEXT:\n{node_ctx}\n\n"
|
||||
f"CONVERSATION MESSAGES:\n{formatted_messages}\n\n"
|
||||
"INSTRUCTIONS:\n"
|
||||
f"Write a summary of approximately {target_chars} characters "
|
||||
f"(~{target_tokens} tokens).\n"
|
||||
"1. Preserve ALL user-stated rules, constraints, and preferences "
|
||||
"verbatim.\n"
|
||||
"2. Preserve key decisions made and results obtained.\n"
|
||||
"3. Preserve in-progress work state so the agent can continue.\n"
|
||||
"4. Be detailed enough that the agent can resume without "
|
||||
"re-doing work.\n"
|
||||
)
|
||||
|
||||
|
||||
def build_message_inventory(conversation: NodeConversation) -> list[dict[str, Any]]:
|
||||
"""Build a per-message size inventory for debug logging."""
|
||||
inventory: list[dict[str, Any]] = []
|
||||
for message in conversation.messages:
|
||||
content_chars = len(message.content)
|
||||
tool_call_args_chars = 0
|
||||
tool_name = None
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
args = tool_call.get("function", {}).get("arguments", "")
|
||||
tool_call_args_chars += (
|
||||
len(args) if isinstance(args, str) else len(json.dumps(args))
|
||||
)
|
||||
names = [
|
||||
tool_call.get("function", {}).get("name", "?") for tool_call in message.tool_calls
|
||||
]
|
||||
tool_name = ", ".join(names)
|
||||
elif message.role == "tool" and message.tool_use_id:
|
||||
for previous in conversation.messages:
|
||||
if previous.tool_calls:
|
||||
for tool_call in previous.tool_calls:
|
||||
if tool_call.get("id") == message.tool_use_id:
|
||||
tool_name = tool_call.get("function", {}).get("name", "?")
|
||||
break
|
||||
if tool_name:
|
||||
break
|
||||
entry: dict[str, Any] = {
|
||||
"seq": message.seq,
|
||||
"role": message.role,
|
||||
"content_chars": content_chars,
|
||||
}
|
||||
if tool_call_args_chars:
|
||||
entry["tool_call_args_chars"] = tool_call_args_chars
|
||||
if tool_name:
|
||||
entry["tool"] = tool_name
|
||||
if message.is_error:
|
||||
entry["is_error"] = True
|
||||
if message.phase_id:
|
||||
entry["phase"] = message.phase_id
|
||||
if content_chars > 2000:
|
||||
entry["preview"] = message.content[:200] + "…"
|
||||
inventory.append(entry)
|
||||
return inventory
|
||||
|
||||
|
||||
def write_compaction_debug_log(
|
||||
ctx: NodeContext,
|
||||
before_pct: int,
|
||||
after_pct: int,
|
||||
level: str,
|
||||
inventory: list[dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Write detailed compaction analysis to ~/.hive/compaction_log/."""
|
||||
log_dir = Path.home() / ".hive" / "compaction_log"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S_%f")
|
||||
node_label = ctx.node_id.replace("/", "_")
|
||||
log_path = log_dir / f"{ts}_{node_label}.md"
|
||||
|
||||
lines: list[str] = [
|
||||
f"# Compaction Debug — {ctx.node_id}",
|
||||
f"**Time:** {datetime.now(UTC).isoformat()}",
|
||||
f"**Node:** {ctx.node_spec.name} (`{ctx.node_id}`)",
|
||||
]
|
||||
if ctx.stream_id:
|
||||
lines.append(f"**Stream:** {ctx.stream_id}")
|
||||
lines.append(f"**Level:** {level}")
|
||||
lines.append(f"**Usage:** {before_pct}% → {after_pct}%")
|
||||
lines.append("")
|
||||
|
||||
if inventory:
|
||||
total_chars = sum(
|
||||
entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
for entry in inventory
|
||||
)
|
||||
lines.append(
|
||||
"## Pre-Compaction Message Inventory "
|
||||
f"({len(inventory)} messages, {total_chars:,} total chars)"
|
||||
)
|
||||
lines.append("")
|
||||
ranked = sorted(
|
||||
inventory,
|
||||
key=lambda entry: entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0),
|
||||
reverse=True,
|
||||
)
|
||||
lines.append("| # | seq | role | tool | chars | % of total | flags |")
|
||||
lines.append("|---|-----|------|------|------:|------------|-------|")
|
||||
for i, entry in enumerate(ranked, 1):
|
||||
chars = entry.get("content_chars", 0) + entry.get("tool_call_args_chars", 0)
|
||||
pct = (chars / total_chars * 100) if total_chars else 0
|
||||
tool = entry.get("tool", "")
|
||||
flags: list[str] = []
|
||||
if entry.get("is_error"):
|
||||
flags.append("error")
|
||||
if entry.get("phase"):
|
||||
flags.append(f"phase={entry['phase']}")
|
||||
lines.append(
|
||||
f"| {i} | {entry['seq']} | {entry['role']} | {tool} "
|
||||
f"| {chars:,} | {pct:.1f}% | {', '.join(flags)} |"
|
||||
)
|
||||
|
||||
large = [entry for entry in ranked if entry.get("preview")]
|
||||
if large:
|
||||
lines.append("")
|
||||
lines.append("### Large message previews")
|
||||
for entry in large:
|
||||
lines.append(
|
||||
f"\n**seq={entry['seq']}** ({entry['role']}, {entry.get('tool', '')}):"
|
||||
)
|
||||
lines.append(f"```\n{entry['preview']}\n```")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
log_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
logger.debug("Compaction debug log written to %s", log_path)
|
||||
except OSError:
|
||||
logger.debug("Failed to write compaction debug log to %s", log_path)
|
||||
|
||||
|
||||
async def log_compaction(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
ratio_before: float,
|
||||
event_bus: EventBus | None,
|
||||
*,
|
||||
pre_inventory: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Log compaction result to runtime logger and event bus."""
|
||||
ratio_after = conversation.usage_ratio()
|
||||
before_pct = round(ratio_before * 100)
|
||||
after_pct = round(ratio_after * 100)
|
||||
|
||||
# Determine label from what happened
|
||||
if after_pct >= before_pct - 1:
|
||||
level = "prune_only"
|
||||
elif ratio_after <= 0.6:
|
||||
level = "llm"
|
||||
else:
|
||||
level = "structural"
|
||||
|
||||
logger.info(
|
||||
"Compaction complete (%s): %d%% -> %d%%",
|
||||
level,
|
||||
before_pct,
|
||||
after_pct,
|
||||
)
|
||||
|
||||
if ctx.runtime_logger:
|
||||
ctx.runtime_logger.log_step(
|
||||
node_id=ctx.node_id,
|
||||
node_type="event_loop",
|
||||
step_index=-1,
|
||||
llm_text=f"Context compacted ({level}): {before_pct}% \u2192 {after_pct}%",
|
||||
verdict="COMPACTION",
|
||||
verdict_feedback=f"level={level} before={before_pct}% after={after_pct}%",
|
||||
)
|
||||
|
||||
if event_bus:
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
event_data: dict[str, Any] = {
|
||||
"level": level,
|
||||
"usage_before": before_pct,
|
||||
"usage_after": after_pct,
|
||||
}
|
||||
if pre_inventory is not None:
|
||||
event_data["message_inventory"] = pre_inventory
|
||||
await event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_COMPACTED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data=event_data,
|
||||
)
|
||||
)
|
||||
|
||||
await publish_context_usage(event_bus, ctx, conversation, "post_compaction")
|
||||
|
||||
if os.environ.get("HIVE_COMPACTION_DEBUG"):
|
||||
write_compaction_debug_log(ctx, before_pct, after_pct, level, pre_inventory)
|
||||
|
||||
|
||||
def build_emergency_summary(
|
||||
ctx: NodeContext,
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
conversation: NodeConversation | None = None,
|
||||
config: LoopConfig | None = None,
|
||||
) -> str:
|
||||
"""Build a structured emergency compaction summary.
|
||||
|
||||
Unlike normal/aggressive compaction which uses an LLM summary,
|
||||
emergency compaction cannot afford an LLM call (context is already
|
||||
way over budget). Instead, build a deterministic summary from the
|
||||
node's known state so the LLM can continue working after
|
||||
compaction without losing track of its task and inputs.
|
||||
"""
|
||||
parts = [
|
||||
"EMERGENCY COMPACTION — previous conversation was too large "
|
||||
"and has been replaced with this summary.\n"
|
||||
]
|
||||
|
||||
# 1. Node identity
|
||||
spec = ctx.node_spec
|
||||
parts.append(f"NODE: {spec.name} (id={spec.id})")
|
||||
if spec.description:
|
||||
parts.append(f"PURPOSE: {spec.description}")
|
||||
|
||||
# 2. Inputs the node received
|
||||
input_lines = []
|
||||
for key in spec.input_keys:
|
||||
value = ctx.input_data.get(key) or ctx.memory.read(key)
|
||||
if value is not None:
|
||||
# Truncate long values but keep them recognisable
|
||||
v_str = str(value)
|
||||
if len(v_str) > 200:
|
||||
v_str = v_str[:200] + "…"
|
||||
input_lines.append(f" {key}: {v_str}")
|
||||
if input_lines:
|
||||
parts.append("INPUTS:\n" + "\n".join(input_lines))
|
||||
|
||||
# 3. Output accumulator state (what's been set so far)
|
||||
if accumulator:
|
||||
acc_state = accumulator.to_dict()
|
||||
set_keys = {k: v for k, v in acc_state.items() if v is not None}
|
||||
missing = [k for k, v in acc_state.items() if v is None]
|
||||
if set_keys:
|
||||
lines = [f" {k}: {str(v)[:150]}" for k, v in set_keys.items()]
|
||||
parts.append("OUTPUTS ALREADY SET:\n" + "\n".join(lines))
|
||||
if missing:
|
||||
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(missing)}")
|
||||
elif spec.output_keys:
|
||||
parts.append(f"OUTPUTS STILL NEEDED: {', '.join(spec.output_keys)}")
|
||||
|
||||
# 4. Available tools reminder
|
||||
if spec.tools:
|
||||
parts.append(f"AVAILABLE TOOLS: {', '.join(spec.tools)}")
|
||||
|
||||
# 5. Spillover files — list actual files so the LLM can load
|
||||
# them immediately instead of having to call list_data_files first.
|
||||
# Inline adapt.md (agent memory) directly — it contains user rules
|
||||
# and identity preferences that must survive emergency compaction.
|
||||
spillover_dir = config.spillover_dir if config else None
|
||||
if spillover_dir:
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
data_dir = Path(spillover_dir)
|
||||
if data_dir.is_dir():
|
||||
# Inline adapt.md content directly
|
||||
adapt_path = data_dir / "adapt.md"
|
||||
if adapt_path.is_file():
|
||||
adapt_text = adapt_path.read_text(encoding="utf-8").strip()
|
||||
if adapt_text:
|
||||
parts.append(f"AGENT MEMORY (adapt.md):\n{adapt_text}")
|
||||
|
||||
all_files = sorted(
|
||||
f.name for f in data_dir.iterdir() if f.is_file() and f.name != "adapt.md"
|
||||
)
|
||||
# Separate conversation history files from regular data files
|
||||
conv_files = [f for f in all_files if re.match(r"conversation_\d+\.md$", f)]
|
||||
data_files = [f for f in all_files if f not in conv_files]
|
||||
|
||||
if conv_files:
|
||||
conv_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in conv_files
|
||||
)
|
||||
parts.append(
|
||||
"CONVERSATION HISTORY (freeform messages saved during compaction — "
|
||||
"use load_data('<filename>') to review earlier dialogue):\n" + conv_list
|
||||
)
|
||||
if data_files:
|
||||
file_list = "\n".join(
|
||||
f" - {f} (full path: {data_dir / f})" for f in data_files[:30]
|
||||
)
|
||||
parts.append("DATA FILES (use load_data('<filename>') to read):\n" + file_list)
|
||||
if not all_files:
|
||||
parts.append(
|
||||
"NOTE: Large tool results may have been saved to files. "
|
||||
"Use list_directory to check the data directory."
|
||||
)
|
||||
except Exception:
|
||||
parts.append(
|
||||
"NOTE: Large tool results were saved to files. "
|
||||
"Use read_file(path='<path>') to read them."
|
||||
)
|
||||
|
||||
# 6. Tool call history (prevent re-calling tools)
|
||||
if conversation is not None:
|
||||
tool_history = _extract_tool_call_history(conversation)
|
||||
if tool_history:
|
||||
parts.append(tool_history)
|
||||
|
||||
parts.append(
|
||||
"\nContinue working towards setting the remaining outputs. "
|
||||
"Use your tools and the inputs above."
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _extract_tool_call_history(conversation: NodeConversation) -> str:
|
||||
"""Extract tool call history from conversation messages.
|
||||
|
||||
This is the instance-level variant that operates on a NodeConversation
|
||||
directly (vs. the module-level extract_tool_call_history in conversation.py
|
||||
which works on raw message lists).
|
||||
"""
|
||||
from framework.graph.conversation import extract_tool_call_history
|
||||
|
||||
return extract_tool_call_history(list(conversation.messages))
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Cursor persistence, queue draining, and pause detection.
|
||||
|
||||
Handles the checkpoint/resume cycle: restoring state from a previous
|
||||
conversation store, writing cursor data, and managing injection/trigger
|
||||
queues between iterations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from framework.graph.conversation import ConversationStore, NodeConversation
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator, TriggerEvent
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.llm.capabilities import supports_image_tool_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestoredState:
|
||||
"""State recovered from a previous checkpoint."""
|
||||
|
||||
conversation: NodeConversation
|
||||
accumulator: OutputAccumulator
|
||||
start_iteration: int
|
||||
recent_responses: list[str]
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]]
|
||||
|
||||
|
||||
async def restore(
|
||||
conversation_store: ConversationStore | None,
|
||||
ctx: NodeContext,
|
||||
config: LoopConfig,
|
||||
) -> RestoredState | None:
|
||||
"""Attempt to restore from a previous checkpoint.
|
||||
|
||||
Returns a ``RestoredState`` with conversation, accumulator, iteration
|
||||
counter, and stall/doom-loop detection state — everything needed to
|
||||
resume exactly where execution stopped.
|
||||
"""
|
||||
if conversation_store is None:
|
||||
return None
|
||||
|
||||
# In isolated mode, filter parts by phase_id so the node only sees
|
||||
# its own messages in the shared flat conversation store. In
|
||||
# continuous mode (or when _restore is called for timer-resume)
|
||||
# load all parts — the full conversation threads across nodes.
|
||||
_is_continuous = getattr(ctx, "continuous_mode", False)
|
||||
phase_filter = None if _is_continuous else ctx.node_id
|
||||
conversation = await NodeConversation.restore(
|
||||
conversation_store,
|
||||
phase_id=phase_filter,
|
||||
)
|
||||
if conversation is None:
|
||||
return None
|
||||
|
||||
accumulator = await OutputAccumulator.restore(conversation_store)
|
||||
accumulator.spillover_dir = config.spillover_dir
|
||||
accumulator.max_value_chars = config.max_output_value_chars
|
||||
|
||||
cursor = await conversation_store.read_cursor()
|
||||
start_iteration = cursor.get("iteration", 0) + 1 if cursor else 0
|
||||
|
||||
# Restore stall/doom-loop detection state
|
||||
recent_responses: list[str] = cursor.get("recent_responses", []) if cursor else []
|
||||
raw_fps = cursor.get("recent_tool_fingerprints", []) if cursor else []
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]] = [
|
||||
[tuple(pair) for pair in fps] # type: ignore[misc]
|
||||
for fps in raw_fps
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Restored event loop: iteration={start_iteration}, "
|
||||
f"messages={conversation.message_count}, "
|
||||
f"outputs={list(accumulator.values.keys())}, "
|
||||
f"stall_window={len(recent_responses)}, "
|
||||
f"doom_window={len(recent_tool_fingerprints)}"
|
||||
)
|
||||
return RestoredState(
|
||||
conversation=conversation,
|
||||
accumulator=accumulator,
|
||||
start_iteration=start_iteration,
|
||||
recent_responses=recent_responses,
|
||||
recent_tool_fingerprints=recent_tool_fingerprints,
|
||||
)
|
||||
|
||||
|
||||
async def write_cursor(
|
||||
conversation_store: ConversationStore | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
iteration: int,
|
||||
*,
|
||||
recent_responses: list[str] | None = None,
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]] | None = None,
|
||||
) -> None:
|
||||
"""Write checkpoint cursor for crash recovery.
|
||||
|
||||
Persists iteration counter, accumulator outputs, and stall/doom-loop
|
||||
detection state so that resume picks up exactly where execution stopped.
|
||||
"""
|
||||
if conversation_store:
|
||||
cursor = await conversation_store.read_cursor() or {}
|
||||
cursor.update(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"node_id": ctx.node_id,
|
||||
"next_seq": conversation.next_seq,
|
||||
"outputs": accumulator.to_dict(),
|
||||
}
|
||||
)
|
||||
# Persist stall/doom-loop detection state for reliable resume
|
||||
if recent_responses is not None:
|
||||
cursor["recent_responses"] = recent_responses
|
||||
if recent_tool_fingerprints is not None:
|
||||
# Convert list[list[tuple]] → list[list[list]] for JSON
|
||||
cursor["recent_tool_fingerprints"] = [
|
||||
[list(pair) for pair in fps] for fps in recent_tool_fingerprints
|
||||
]
|
||||
await conversation_store.write_cursor(cursor)
|
||||
|
||||
|
||||
async def drain_injection_queue(
|
||||
queue: asyncio.Queue,
|
||||
conversation: NodeConversation,
|
||||
*,
|
||||
ctx: NodeContext,
|
||||
describe_images_as_text_fn: (
|
||||
Callable[[list[dict[str, Any]]], Awaitable[str | None]] | None
|
||||
) = None,
|
||||
) -> int:
|
||||
"""Drain all pending injected events as user messages. Returns count."""
|
||||
count = 0
|
||||
while not queue.empty():
|
||||
try:
|
||||
content, is_client_input, image_content = queue.get_nowait()
|
||||
logger.info(
|
||||
"[drain] injected message (client_input=%s, images=%d): %s",
|
||||
is_client_input,
|
||||
len(image_content) if image_content else 0,
|
||||
content[:200] if content else "(empty)",
|
||||
)
|
||||
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Model '%s' does not support images; attempting vision fallback",
|
||||
ctx.llm.model,
|
||||
)
|
||||
if describe_images_as_text_fn is not None:
|
||||
description = await describe_images_as_text_fn(image_content)
|
||||
if description:
|
||||
content = f"{content}\n\n{description}" if content else description
|
||||
logger.info("[drain] image described as text via vision fallback")
|
||||
else:
|
||||
logger.info("[drain] no vision fallback available; images dropped")
|
||||
image_content = None
|
||||
# Real user input is stored as-is; external events get a prefix
|
||||
if is_client_input:
|
||||
await conversation.add_user_message(
|
||||
content,
|
||||
is_client_input=True,
|
||||
image_content=image_content,
|
||||
)
|
||||
else:
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return count
|
||||
|
||||
|
||||
async def drain_trigger_queue(
|
||||
queue: asyncio.Queue,
|
||||
conversation: NodeConversation,
|
||||
) -> int:
|
||||
"""Drain all pending trigger events as a single batched user message.
|
||||
|
||||
Multiple triggers are merged so the LLM sees them atomically and can
|
||||
reason about all pending triggers before acting.
|
||||
"""
|
||||
triggers: list[TriggerEvent] = []
|
||||
while not queue.empty():
|
||||
try:
|
||||
triggers.append(queue.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
if not triggers:
|
||||
return 0
|
||||
|
||||
parts: list[str] = []
|
||||
for t in triggers:
|
||||
task = t.payload.get("task", "")
|
||||
task_line = f"\nTask: {task}" if task else ""
|
||||
payload_str = json.dumps(t.payload, default=str)
|
||||
parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200])
|
||||
await conversation.add_user_message(combined)
|
||||
return len(triggers)
|
||||
|
||||
|
||||
async def check_pause(
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
iteration: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if pause has been requested. Returns True if paused.
|
||||
|
||||
Note: This check happens BEFORE starting iteration N, after completing N-1.
|
||||
If paused, the node exits having completed {iteration} iterations (0 to iteration-1).
|
||||
"""
|
||||
# Check executor-level pause event (for /pause command, Ctrl+Z)
|
||||
if ctx.pause_event and ctx.pause_event.is_set():
|
||||
completed = iteration # 0-indexed: iteration=3 means 3 iterations completed (0,1,2)
|
||||
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (executor-level)")
|
||||
return True
|
||||
|
||||
# Check context-level pause flags (legacy/alternative methods)
|
||||
pause_requested = ctx.input_data.get("pause_requested", False)
|
||||
if not pause_requested:
|
||||
try:
|
||||
pause_requested = ctx.memory.read("pause_requested") or False
|
||||
except (PermissionError, KeyError):
|
||||
pause_requested = False
|
||||
if pause_requested:
|
||||
completed = iteration
|
||||
logger.info(f"⏸ Pausing after {completed} iteration(s) completed (context-level)")
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,360 @@
|
||||
"""EventBus publishing helpers for the event loop.
|
||||
|
||||
Thin wrappers around EventBus.emit_*() calls that check for bus existence
|
||||
before publishing. Extracted to reduce noise in the main orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.types import HookContext
|
||||
from framework.graph.node import NodeContext
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def publish_loop_started(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
max_iterations: int,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
max_iterations=max_iterations,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def generate_action_plan(
|
||||
event_bus: EventBus | None,
|
||||
ctx: NodeContext,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str,
|
||||
) -> None:
|
||||
"""Generate a brief action plan via LLM and emit it as an SSE event.
|
||||
|
||||
Runs as a fire-and-forget task so it never blocks the main loop.
|
||||
"""
|
||||
try:
|
||||
system_prompt = ctx.node_spec.system_prompt or ""
|
||||
# Trim to keep the prompt small
|
||||
prompt_summary = system_prompt[:500]
|
||||
if len(system_prompt) > 500:
|
||||
prompt_summary += "..."
|
||||
|
||||
tool_names = [t.name for t in ctx.available_tools]
|
||||
output_keys = ctx.node_spec.output_keys or []
|
||||
|
||||
prompt = (
|
||||
f'You are about to work on a task as node "{node_id}".\n\n'
|
||||
f"System prompt:\n{prompt_summary}\n\n"
|
||||
f"Tools available: {tool_names}\n"
|
||||
f"Required outputs: {output_keys}\n\n"
|
||||
f"Write a brief action plan (2-5 bullet points) describing "
|
||||
f"what you will do to complete this task. Be specific and concise.\n"
|
||||
f"Return ONLY the plan text, no preamble."
|
||||
)
|
||||
|
||||
response = await ctx.llm.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
plan = response.content.strip()
|
||||
if plan and event_bus:
|
||||
await event_bus.emit_node_action_plan(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
plan=plan,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Action plan generation failed for node '%s': %s", node_id, e)
|
||||
|
||||
|
||||
async def publish_iteration(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
execution_id: str = "",
|
||||
extra_data: dict | None = None,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_iteration(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iteration=iteration,
|
||||
execution_id=execution_id,
|
||||
extra_data=extra_data,
|
||||
)
|
||||
|
||||
|
||||
async def publish_llm_turn_complete(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
stop_reason: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_llm_turn_complete(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
stop_reason=stop_reason,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
|
||||
def log_skip_judge(
|
||||
ctx: NodeContext,
|
||||
node_id: str,
|
||||
iteration: int,
|
||||
feedback: str,
|
||||
tool_calls: list[dict],
|
||||
llm_text: str,
|
||||
turn_tokens: dict[str, int],
|
||||
iter_start: float,
|
||||
) -> None:
|
||||
"""Log a CONTINUE step that skips judge evaluation (e.g., waiting for input)."""
|
||||
if ctx.runtime_logger:
|
||||
ctx.runtime_logger.log_step(
|
||||
node_id=node_id,
|
||||
node_type="event_loop",
|
||||
step_index=iteration,
|
||||
verdict="CONTINUE",
|
||||
verdict_feedback=feedback,
|
||||
tool_calls=tool_calls,
|
||||
llm_text=llm_text,
|
||||
input_tokens=turn_tokens.get("input", 0),
|
||||
output_tokens=turn_tokens.get("output", 0),
|
||||
latency_ms=int((time.time() - iter_start) * 1000),
|
||||
)
|
||||
|
||||
|
||||
async def publish_loop_completed(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
iterations: int,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_loop_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
iterations=iterations,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_context_usage(
|
||||
event_bus: EventBus | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
trigger: str,
|
||||
) -> None:
|
||||
"""Emit a CONTEXT_USAGE_UPDATED event with current context window state."""
|
||||
if not event_bus:
|
||||
return
|
||||
|
||||
from framework.runtime.event_bus import AgentEvent, EventType
|
||||
|
||||
estimated = conversation.estimate_tokens()
|
||||
max_tokens = conversation._max_context_tokens
|
||||
ratio = estimated / max_tokens if max_tokens > 0 else 0.0
|
||||
await event_bus.publish(
|
||||
AgentEvent(
|
||||
type=EventType.CONTEXT_USAGE_UPDATED,
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=ctx.node_id,
|
||||
data={
|
||||
"usage_ratio": round(ratio, 4),
|
||||
"usage_pct": round(ratio * 100),
|
||||
"message_count": conversation.message_count,
|
||||
"estimated_tokens": estimated,
|
||||
"max_context_tokens": max_tokens,
|
||||
"trigger": trigger,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def publish_stalled(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_node_stalled(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
reason="Consecutive similar responses detected",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_text_delta(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
content: str,
|
||||
snapshot: str,
|
||||
ctx: NodeContext,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
inner_turn: int = 0,
|
||||
) -> None:
|
||||
if event_bus:
|
||||
if ctx.node_spec.client_facing:
|
||||
await event_bus.emit_client_output_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
inner_turn=inner_turn,
|
||||
)
|
||||
else:
|
||||
await event_bus.emit_llm_text_delta(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
content=content,
|
||||
snapshot=snapshot,
|
||||
execution_id=execution_id,
|
||||
inner_turn=inner_turn,
|
||||
)
|
||||
|
||||
|
||||
async def publish_tool_started(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
tool_input: dict,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_tool_call_started(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_tool_completed(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
tool_use_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
is_error: bool,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_tool_call_completed(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
tool_use_id=tool_use_id,
|
||||
tool_name=tool_name,
|
||||
result=result,
|
||||
is_error=is_error,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_judge_verdict(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
action: str,
|
||||
feedback: str = "",
|
||||
judge_type: str = "implicit",
|
||||
iteration: int = 0,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_judge_verdict(
|
||||
stream_id=stream_id,
|
||||
node_id=node_id,
|
||||
action=action,
|
||||
feedback=feedback,
|
||||
judge_type=judge_type,
|
||||
iteration=iteration,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
|
||||
async def publish_output_key_set(
|
||||
event_bus: EventBus | None,
|
||||
stream_id: str,
|
||||
node_id: str,
|
||||
key: str,
|
||||
execution_id: str = "",
|
||||
) -> None:
|
||||
if event_bus:
|
||||
await event_bus.emit_output_key_set(
|
||||
stream_id=stream_id, node_id=node_id, key=key, execution_id=execution_id
|
||||
)
|
||||
|
||||
|
||||
async def run_hooks(
|
||||
hooks_config: dict[str, list],
|
||||
event: str,
|
||||
conversation: NodeConversation,
|
||||
trigger: str | None = None,
|
||||
) -> None:
|
||||
"""Run all registered hooks for *event*, applying their results.
|
||||
|
||||
Each hook receives a HookContext and may return a HookResult that:
|
||||
- replaces the system prompt (result.system_prompt)
|
||||
- injects an extra user message (result.inject)
|
||||
Hooks run in registration order; each sees the prompt as left by the
|
||||
previous hook.
|
||||
"""
|
||||
hook_list = hooks_config.get(event, [])
|
||||
if not hook_list:
|
||||
return
|
||||
for hook in hook_list:
|
||||
ctx = HookContext(
|
||||
event=event,
|
||||
trigger=trigger,
|
||||
system_prompt=conversation.system_prompt,
|
||||
)
|
||||
try:
|
||||
result = await hook(ctx)
|
||||
except Exception:
|
||||
logger.warning("Hook '%s' raised an exception", event, exc_info=True)
|
||||
continue
|
||||
if result is None:
|
||||
continue
|
||||
if result.system_prompt:
|
||||
conversation.update_system_prompt(result.system_prompt)
|
||||
if result.inject:
|
||||
await conversation.add_user_message(result.inject)
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Judge evaluation pipeline for the event loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from framework.graph.conversation import NodeConversation
|
||||
from framework.graph.event_loop.types import JudgeProtocol, JudgeVerdict, OutputAccumulator
|
||||
from framework.graph.node import NodeContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubagentJudge:
|
||||
"""Judge for subagent execution."""
|
||||
|
||||
def __init__(self, task: str, max_iterations: int = 10):
|
||||
self._task = task
|
||||
self._max_iterations = max_iterations
|
||||
|
||||
async def evaluate(self, context: dict[str, object]) -> JudgeVerdict:
|
||||
missing = context.get("missing_keys", [])
|
||||
if not isinstance(missing, list) or not missing:
|
||||
return JudgeVerdict(action="ACCEPT", feedback="")
|
||||
|
||||
iteration = context.get("iteration", 0)
|
||||
if not isinstance(iteration, int):
|
||||
iteration = 0
|
||||
remaining = self._max_iterations - iteration - 1
|
||||
|
||||
if remaining <= 3:
|
||||
urgency = (
|
||||
f"URGENT: Only {remaining} iterations left. "
|
||||
f"Stop all other work and call set_output NOW for: {missing}"
|
||||
)
|
||||
elif remaining <= self._max_iterations // 2:
|
||||
urgency = (
|
||||
f"WARNING: {remaining} iterations remaining. "
|
||||
f"You must call set_output for: {missing}"
|
||||
)
|
||||
else:
|
||||
urgency = f"Missing output keys: {missing}. Use set_output to provide them."
|
||||
|
||||
return JudgeVerdict(action="RETRY", feedback=f"Your task: {self._task}\n{urgency}")
|
||||
|
||||
|
||||
async def judge_turn(
|
||||
*,
|
||||
mark_complete_flag: bool,
|
||||
judge: JudgeProtocol | None,
|
||||
ctx: NodeContext,
|
||||
conversation: NodeConversation,
|
||||
accumulator: OutputAccumulator,
|
||||
assistant_text: str,
|
||||
tool_results: list[dict[str, object]],
|
||||
iteration: int,
|
||||
get_missing_output_keys_fn: Callable[
|
||||
[OutputAccumulator, list[str] | None, list[str] | None],
|
||||
list[str],
|
||||
],
|
||||
max_context_tokens: int,
|
||||
) -> JudgeVerdict:
|
||||
"""Evaluate the current state using judge or implicit logic.
|
||||
|
||||
Evaluation levels (in order):
|
||||
0. Short-circuits: mark_complete, skip_judge, tool-continue.
|
||||
1. Custom judge (JudgeProtocol) — full authority when set.
|
||||
2. Implicit judge — output-key check + optional conversation-aware
|
||||
quality gate (when ``success_criteria`` is defined).
|
||||
|
||||
Returns a JudgeVerdict. ``feedback=None`` means no real evaluation
|
||||
happened (skip_judge, tool-continue); the caller must not inject a
|
||||
feedback message. Any non-None feedback (including ``""``) means a
|
||||
real evaluation occurred and will be logged into the conversation.
|
||||
"""
|
||||
# --- Level 0: short-circuits (no evaluation) -----------------------
|
||||
|
||||
if mark_complete_flag:
|
||||
return JudgeVerdict(action="ACCEPT")
|
||||
|
||||
if ctx.node_spec.skip_judge:
|
||||
return JudgeVerdict(action="RETRY") # feedback=None → not logged
|
||||
|
||||
# --- Level 1: custom judge -----------------------------------------
|
||||
|
||||
if judge is not None:
|
||||
context = {
|
||||
"assistant_text": assistant_text,
|
||||
"tool_calls": tool_results,
|
||||
"output_accumulator": accumulator.to_dict(),
|
||||
"accumulator": accumulator,
|
||||
"iteration": iteration,
|
||||
"conversation_summary": conversation.export_summary(),
|
||||
"output_keys": ctx.node_spec.output_keys,
|
||||
"missing_keys": get_missing_output_keys_fn(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
),
|
||||
}
|
||||
verdict = await judge.evaluate(context)
|
||||
# Ensure evaluated RETRY always carries feedback for logging.
|
||||
if verdict.action == "RETRY" and not verdict.feedback:
|
||||
return JudgeVerdict(action="RETRY", feedback="Custom judge returned RETRY.")
|
||||
return verdict
|
||||
|
||||
# --- Level 2: implicit judge ---------------------------------------
|
||||
|
||||
# Real tool calls were made — let the agent keep working.
|
||||
if tool_results:
|
||||
return JudgeVerdict(action="RETRY") # feedback=None → not logged
|
||||
|
||||
missing = get_missing_output_keys_fn(
|
||||
accumulator, ctx.node_spec.output_keys, ctx.node_spec.nullable_output_keys
|
||||
)
|
||||
|
||||
if missing:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
f"Task incomplete. Required outputs not yet produced: {missing}. "
|
||||
f"Follow your system prompt instructions to complete the work."
|
||||
),
|
||||
)
|
||||
|
||||
# All output keys present — run safety checks before accepting.
|
||||
|
||||
output_keys = ctx.node_spec.output_keys or []
|
||||
nullable_keys = set(ctx.node_spec.nullable_output_keys or [])
|
||||
|
||||
# All-nullable with nothing set → node produced nothing useful.
|
||||
all_nullable = output_keys and nullable_keys >= set(output_keys)
|
||||
none_set = not any(accumulator.get(k) is not None for k in output_keys)
|
||||
if all_nullable and none_set:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
f"No output keys have been set yet. "
|
||||
f"Use set_output to set at least one of: {output_keys}"
|
||||
),
|
||||
)
|
||||
|
||||
# Client-facing with no output keys → continuous interaction node.
|
||||
# Inject tool-use pressure instead of auto-accepting.
|
||||
if not output_keys and ctx.node_spec.client_facing:
|
||||
return JudgeVerdict(
|
||||
action="RETRY",
|
||||
feedback=(
|
||||
"STOP describing what you will do. "
|
||||
"You have FULL access to all tools — file creation, "
|
||||
"shell commands, MCP tools — and you CAN call them "
|
||||
"directly in your response. Respond ONLY with tool "
|
||||
"calls, no prose. Execute the task now."
|
||||
),
|
||||
)
|
||||
|
||||
# Level 2b: conversation-aware quality check (if success_criteria set)
|
||||
if ctx.node_spec.success_criteria and ctx.llm:
|
||||
from framework.graph.conversation_judge import evaluate_phase_completion
|
||||
|
||||
verdict = await evaluate_phase_completion(
|
||||
llm=ctx.llm,
|
||||
conversation=conversation,
|
||||
phase_name=ctx.node_spec.name,
|
||||
phase_description=ctx.node_spec.description,
|
||||
success_criteria=ctx.node_spec.success_criteria,
|
||||
accumulator_state=accumulator.to_dict(),
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
if verdict.action != "ACCEPT":
|
||||
return JudgeVerdict(
|
||||
action=verdict.action,
|
||||
feedback=verdict.feedback or "Phase criteria not met.",
|
||||
)
|
||||
|
||||
return JudgeVerdict(action="ACCEPT", feedback="")
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Stall and doom-loop detection for the event loop.
|
||||
|
||||
Pure functions with no class dependencies — safe to call from any context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def ngram_similarity(s1: str, s2: str, n: int = 2) -> float:
|
||||
"""Jaccard similarity of n-gram sets.
|
||||
|
||||
Returns 0.0-1.0, where 1.0 is exact match.
|
||||
Fast: O(len(s) + len(s2)) using set operations.
|
||||
"""
|
||||
|
||||
def _ngrams(s: str) -> set[str]:
|
||||
return {s[i : i + n] for i in range(len(s) - n + 1) if s.strip()}
|
||||
|
||||
if not s1 or not s2:
|
||||
return 0.0
|
||||
|
||||
ngrams1, ngrams2 = _ngrams(s1.lower()), _ngrams(s2.lower())
|
||||
if not ngrams1 or not ngrams2:
|
||||
return 0.0
|
||||
|
||||
intersection = len(ngrams1 & ngrams2)
|
||||
union = len(ngrams1 | ngrams2)
|
||||
return intersection / union if union else 0.0
|
||||
|
||||
|
||||
def is_stalled(
|
||||
recent_responses: list[str],
|
||||
threshold: int,
|
||||
similarity_threshold: float,
|
||||
) -> bool:
|
||||
"""Detect stall using n-gram similarity.
|
||||
|
||||
Detects when ALL N consecutive responses are mutually similar
|
||||
(>= threshold). A single dissimilar response resets the signal.
|
||||
This catches phrases like "I'm still stuck" vs "I'm stuck"
|
||||
without false-positives on "attempt 1" vs "attempt 2".
|
||||
"""
|
||||
if len(recent_responses) < threshold:
|
||||
return False
|
||||
if not recent_responses[0]:
|
||||
return False
|
||||
|
||||
# Every consecutive pair must be similar
|
||||
for i in range(1, len(recent_responses)):
|
||||
if ngram_similarity(recent_responses[i], recent_responses[i - 1]) < similarity_threshold:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def fingerprint_tool_calls(
|
||||
tool_results: list[dict],
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Create deterministic fingerprints for a turn's tool calls.
|
||||
|
||||
Each fingerprint is (tool_name, canonical_args_json). Order-sensitive
|
||||
so [search("a"), fetch("b")] != [fetch("b"), search("a")].
|
||||
"""
|
||||
fingerprints = []
|
||||
for tr in tool_results:
|
||||
name = tr.get("tool_name", "")
|
||||
args = tr.get("tool_input", {})
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = str(args)
|
||||
fingerprints.append((name, canonical))
|
||||
return fingerprints
|
||||
|
||||
|
||||
def is_tool_doom_loop(
|
||||
recent_tool_fingerprints: list[list[tuple[str, str]]],
|
||||
threshold: int,
|
||||
enabled: bool = True,
|
||||
) -> tuple[bool, str]:
|
||||
"""Detect doom loop via exact fingerprint match.
|
||||
|
||||
Detects when N consecutive turns invoke the same tools with
|
||||
identical (canonicalized) arguments. Different arguments mean
|
||||
different work, so only exact matches count.
|
||||
|
||||
Returns (is_doom_loop, description).
|
||||
"""
|
||||
if not enabled:
|
||||
return False, ""
|
||||
if len(recent_tool_fingerprints) < threshold:
|
||||
return False, ""
|
||||
first = recent_tool_fingerprints[0]
|
||||
if not first:
|
||||
return False, ""
|
||||
|
||||
# All turns in the window must match the first exactly
|
||||
if all(fp == first for fp in recent_tool_fingerprints[1:]):
|
||||
tool_names = [name for name, _ in first]
|
||||
desc = (
|
||||
f"Doom loop detected: {len(recent_tool_fingerprints)} "
|
||||
f"identical consecutive tool calls ({', '.join(tool_names)})"
|
||||
)
|
||||
return True, desc
|
||||
return False, ""
|
||||
@@ -0,0 +1,412 @@
|
||||
"""Subagent execution for the event loop.
|
||||
|
||||
Handles the full subagent lifecycle: validation, context setup, tool filtering,
|
||||
conversation store derivation, execution, and cleanup. Also includes the
|
||||
_EscalationReceiver helper used for subagent → queen escalation routing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.graph.conversation import ConversationStore
|
||||
from framework.graph.event_loop.judge_pipeline import SubagentJudge
|
||||
from framework.graph.event_loop.types import LoopConfig, OutputAccumulator
|
||||
from framework.graph.node import NodeContext, SharedMemory
|
||||
from framework.llm.provider import ToolResult, ToolUse
|
||||
from framework.runtime.event_bus import EventBus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.graph.event_loop_node import EventLoopNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EscalationReceiver:
|
||||
"""Temporary receiver registered in node_registry for subagent escalation routing.
|
||||
|
||||
When a subagent calls ``report_to_parent(wait_for_response=True)``, the callback
|
||||
creates one of these, registers it under a unique escalation ID in the executor's
|
||||
``node_registry``, and awaits ``wait()``. The TUI / runner calls
|
||||
``inject_input(escalation_id, content)`` which the ``ExecutionStream`` routes here
|
||||
via ``inject_event()`` — matching the same ``hasattr(node, "inject_event")`` check
|
||||
used for regular ``EventLoopNode`` instances.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event = asyncio.Event()
|
||||
self._response: str | None = None
|
||||
self._awaiting_input = True # So inject_worker_message() can prefer us
|
||||
|
||||
async def inject_event(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
is_client_input: bool = False,
|
||||
image_content: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Called by ExecutionStream.inject_input() when the user responds."""
|
||||
self._response = content
|
||||
self._event.set()
|
||||
|
||||
async def wait(self) -> str | None:
|
||||
"""Block until inject_event() delivers the user's response."""
|
||||
await self._event.wait()
|
||||
return self._response
|
||||
|
||||
|
||||
async def execute_subagent(
|
||||
ctx: NodeContext,
|
||||
agent_id: str,
|
||||
task: str,
|
||||
*,
|
||||
config: LoopConfig,
|
||||
event_loop_node_cls: type[EventLoopNode],
|
||||
escalation_receiver_cls: type[EscalationReceiver],
|
||||
accumulator: OutputAccumulator | None = None,
|
||||
event_bus: EventBus | None = None,
|
||||
tool_executor: Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None = None,
|
||||
conversation_store: ConversationStore | None = None,
|
||||
subagent_instance_counter: dict[str, int] | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a subagent and return the result as a ToolResult.
|
||||
|
||||
The subagent:
|
||||
- Gets a fresh conversation with just the task
|
||||
- Has read-only access to the parent's readable memory
|
||||
- Cannot delegate to its own subagents (prevents recursion)
|
||||
- Returns its output in structured JSON format
|
||||
|
||||
Args:
|
||||
ctx: Parent node's context (for memory, tools, LLM access).
|
||||
agent_id: The node ID of the subagent to invoke.
|
||||
task: The task description to give the subagent.
|
||||
accumulator: Parent's OutputAccumulator.
|
||||
event_bus: EventBus for lifecycle events.
|
||||
config: LoopConfig for iteration/tool limits.
|
||||
tool_executor: Tool executor callable.
|
||||
conversation_store: Parent conversation store (for deriving subagent store).
|
||||
subagent_instance_counter: Mutable counter dict for unique subagent paths.
|
||||
|
||||
Returns:
|
||||
ToolResult with structured JSON output.
|
||||
"""
|
||||
# Log subagent invocation start
|
||||
logger.info(
|
||||
"\n" + "=" * 60 + "\n"
|
||||
"🤖 SUBAGENT INVOCATION\n"
|
||||
"=" * 60 + "\n"
|
||||
"Parent Node: %s\n"
|
||||
"Subagent ID: %s\n"
|
||||
"Task: %s\n" + "=" * 60,
|
||||
ctx.node_id,
|
||||
agent_id,
|
||||
task[:500] + "..." if len(task) > 500 else task,
|
||||
)
|
||||
|
||||
# 1. Validate agent exists in registry
|
||||
if agent_id not in ctx.node_registry:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(
|
||||
{
|
||||
"message": f"Sub-agent '{agent_id}' not found in registry",
|
||||
"data": None,
|
||||
"metadata": {"agent_id": agent_id, "success": False, "error": "not_found"},
|
||||
}
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
subagent_spec = ctx.node_registry[agent_id]
|
||||
|
||||
# 2. Create read-only memory snapshot
|
||||
parent_data = ctx.memory.read_all()
|
||||
|
||||
# Merge in-flight outputs from the parent's accumulator.
|
||||
if accumulator:
|
||||
for key, value in accumulator.to_dict().items():
|
||||
if key not in parent_data:
|
||||
parent_data[key] = value
|
||||
|
||||
subagent_memory = SharedMemory()
|
||||
for key, value in parent_data.items():
|
||||
subagent_memory.write(key, value, validate=False)
|
||||
|
||||
read_keys = set(parent_data.keys()) | set(subagent_spec.input_keys or [])
|
||||
scoped_memory = subagent_memory.with_permissions(
|
||||
read_keys=list(read_keys),
|
||||
write_keys=[], # Read-only!
|
||||
)
|
||||
|
||||
# 2b. Compute instance counter early so the callback and child context
|
||||
# share the same stable node_id for this subagent invocation.
|
||||
if subagent_instance_counter is not None:
|
||||
subagent_instance_counter.setdefault(agent_id, 0)
|
||||
subagent_instance_counter[agent_id] += 1
|
||||
subagent_instance = str(subagent_instance_counter[agent_id])
|
||||
else:
|
||||
subagent_instance = "1"
|
||||
|
||||
if subagent_instance == "1":
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}"
|
||||
else:
|
||||
sa_node_id = f"{ctx.node_id}:subagent:{agent_id}:{subagent_instance}"
|
||||
|
||||
# 2c. Set up report callback (one-way channel to parent / event bus)
|
||||
subagent_reports: list[dict] = []
|
||||
|
||||
async def _report_callback(
|
||||
message: str,
|
||||
data: dict | None = None,
|
||||
*,
|
||||
wait_for_response: bool = False,
|
||||
) -> str | None:
|
||||
subagent_reports.append({"message": message, "data": data, "timestamp": time.time()})
|
||||
if event_bus:
|
||||
await event_bus.emit_subagent_report(
|
||||
stream_id=ctx.node_id,
|
||||
node_id=sa_node_id,
|
||||
subagent_id=agent_id,
|
||||
message=message,
|
||||
data=data,
|
||||
execution_id=ctx.execution_id,
|
||||
)
|
||||
|
||||
if not wait_for_response:
|
||||
return None
|
||||
|
||||
if not event_bus:
|
||||
logger.warning(
|
||||
"Subagent '%s' requested user response but no event_bus available",
|
||||
agent_id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Create isolated receiver and register for input routing
|
||||
import uuid
|
||||
|
||||
escalation_id = f"{ctx.node_id}:escalation:{uuid.uuid4().hex[:8]}"
|
||||
receiver = escalation_receiver_cls()
|
||||
registry = ctx.shared_node_registry
|
||||
|
||||
registry[escalation_id] = receiver
|
||||
try:
|
||||
await event_bus.emit_escalation_requested(
|
||||
stream_id=ctx.stream_id or ctx.node_id,
|
||||
node_id=escalation_id,
|
||||
reason=f"Subagent report (wait_for_response) from {agent_id}",
|
||||
context=message,
|
||||
execution_id=ctx.execution_id,
|
||||
)
|
||||
# Block until queen responds
|
||||
return await receiver.wait()
|
||||
finally:
|
||||
registry.pop(escalation_id, None)
|
||||
|
||||
# 3. Filter tools for subagent
|
||||
subagent_tool_names = set(subagent_spec.tools or [])
|
||||
tool_source = ctx.all_tools if ctx.all_tools else ctx.available_tools
|
||||
|
||||
# GCU auto-population
|
||||
if subagent_spec.node_type == "gcu" and not subagent_tool_names:
|
||||
subagent_tools = [t for t in tool_source if t.name != "delegate_to_sub_agent"]
|
||||
else:
|
||||
subagent_tools = [
|
||||
t
|
||||
for t in tool_source
|
||||
if t.name in subagent_tool_names and t.name != "delegate_to_sub_agent"
|
||||
]
|
||||
|
||||
missing = subagent_tool_names - {t.name for t in subagent_tools}
|
||||
if missing:
|
||||
logger.warning(
|
||||
"Subagent '%s' requested tools not found in catalog: %s",
|
||||
agent_id,
|
||||
sorted(missing),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"📦 Subagent '%s' configuration:\n"
|
||||
" - System prompt: %s\n"
|
||||
" - Tools available (%d): %s\n"
|
||||
" - Memory keys inherited: %s",
|
||||
agent_id,
|
||||
(subagent_spec.system_prompt[:200] + "...")
|
||||
if subagent_spec.system_prompt and len(subagent_spec.system_prompt) > 200
|
||||
else subagent_spec.system_prompt,
|
||||
len(subagent_tools),
|
||||
[t.name for t in subagent_tools],
|
||||
list(parent_data.keys()),
|
||||
)
|
||||
|
||||
# 4. Build subagent context
|
||||
max_iter = min(config.max_iterations, 10)
|
||||
subagent_ctx = NodeContext(
|
||||
runtime=ctx.runtime,
|
||||
node_id=sa_node_id,
|
||||
node_spec=subagent_spec,
|
||||
memory=scoped_memory,
|
||||
input_data={"task": task, **parent_data},
|
||||
llm=ctx.llm,
|
||||
available_tools=subagent_tools,
|
||||
goal_context=(
|
||||
f"Your specific task: {task}\n\n"
|
||||
f"COMPLETION REQUIREMENTS:\n"
|
||||
f"When your task is done, you MUST call set_output() "
|
||||
f"for each required key: {subagent_spec.output_keys}\n"
|
||||
f"Alternatively, call report_to_parent(mark_complete=true) "
|
||||
f"with your findings in message/data.\n"
|
||||
f"You have a maximum of {max_iter} turns to complete this task."
|
||||
),
|
||||
goal=ctx.goal,
|
||||
max_tokens=ctx.max_tokens,
|
||||
runtime_logger=ctx.runtime_logger,
|
||||
is_subagent_mode=True, # Prevents nested delegation
|
||||
report_callback=_report_callback,
|
||||
node_registry={}, # Empty - no nested subagents
|
||||
shared_node_registry=ctx.shared_node_registry, # For escalation routing
|
||||
)
|
||||
|
||||
# 5. Create and execute subagent EventLoopNode
|
||||
subagent_conv_store = None
|
||||
if conversation_store is not None:
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
parent_base = getattr(conversation_store, "_base", None)
|
||||
if parent_base is not None:
|
||||
conversations_dir = parent_base.parent
|
||||
subagent_dir_name = f"{agent_id}-{subagent_instance}"
|
||||
subagent_store_path = conversations_dir / subagent_dir_name
|
||||
subagent_conv_store = FileConversationStore(base_path=subagent_store_path)
|
||||
|
||||
# Derive a subagent-scoped spillover dir
|
||||
subagent_spillover = None
|
||||
if config.spillover_dir:
|
||||
subagent_spillover = str(Path(config.spillover_dir) / agent_id / subagent_instance)
|
||||
|
||||
subagent_node = event_loop_node_cls(
|
||||
event_bus=event_bus,
|
||||
judge=SubagentJudge(task=task, max_iterations=max_iter),
|
||||
config=LoopConfig(
|
||||
max_iterations=max_iter,
|
||||
max_tool_calls_per_turn=config.max_tool_calls_per_turn,
|
||||
tool_call_overflow_margin=config.tool_call_overflow_margin,
|
||||
max_context_tokens=config.max_context_tokens,
|
||||
stall_detection_threshold=config.stall_detection_threshold,
|
||||
max_tool_result_chars=config.max_tool_result_chars,
|
||||
spillover_dir=subagent_spillover,
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
conversation_store=subagent_conv_store,
|
||||
)
|
||||
|
||||
# Inject a unique GCU browser profile for this subagent
|
||||
_profile_token = None
|
||||
try:
|
||||
from gcu.browser.session import set_active_profile as _set_gcu_profile
|
||||
|
||||
_profile_token = _set_gcu_profile(f"{agent_id}-{subagent_instance}")
|
||||
except ImportError:
|
||||
pass # GCU tools not installed; no-op
|
||||
|
||||
try:
|
||||
logger.info("🚀 Starting subagent '%s' execution...", agent_id)
|
||||
start_time = time.time()
|
||||
result = await subagent_node.execute(subagent_ctx)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
separator = "-" * 60
|
||||
logger.info(
|
||||
"\n%s\n"
|
||||
"✅ SUBAGENT '%s' COMPLETED\n"
|
||||
"%s\n"
|
||||
"Success: %s\n"
|
||||
"Latency: %dms\n"
|
||||
"Tokens used: %s\n"
|
||||
"Output keys: %s\n"
|
||||
"%s",
|
||||
separator,
|
||||
agent_id,
|
||||
separator,
|
||||
result.success,
|
||||
latency_ms,
|
||||
result.tokens_used,
|
||||
list(result.output.keys()) if result.output else [],
|
||||
separator,
|
||||
)
|
||||
|
||||
result_json = {
|
||||
"message": (
|
||||
f"Sub-agent '{agent_id}' completed successfully"
|
||||
if result.success
|
||||
else f"Sub-agent '{agent_id}' failed: {result.error}"
|
||||
),
|
||||
"data": result.output,
|
||||
"reports": subagent_reports if subagent_reports else None,
|
||||
"metadata": {
|
||||
"agent_id": agent_id,
|
||||
"success": result.success,
|
||||
"tokens_used": result.tokens_used,
|
||||
"latency_ms": latency_ms,
|
||||
"report_count": len(subagent_reports),
|
||||
},
|
||||
}
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(result_json, indent=2, default=str),
|
||||
is_error=not result.success,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"\n" + "!" * 60 + "\n❌ SUBAGENT '%s' FAILED\nError: %s\n" + "!" * 60,
|
||||
agent_id,
|
||||
str(e),
|
||||
)
|
||||
result_json = {
|
||||
"message": f"Sub-agent '{agent_id}' raised exception: {e}",
|
||||
"data": None,
|
||||
"metadata": {
|
||||
"agent_id": agent_id,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
},
|
||||
}
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=json.dumps(result_json, indent=2),
|
||||
is_error=True,
|
||||
)
|
||||
finally:
|
||||
# Restore the GCU profile context
|
||||
if _profile_token is not None:
|
||||
from gcu.browser.session import _active_profile as _gcu_profile_var
|
||||
|
||||
_gcu_profile_var.reset(_profile_token)
|
||||
|
||||
# Stop the browser session for this subagent's profile
|
||||
if tool_executor is not None:
|
||||
_subagent_profile = f"{agent_id}-{subagent_instance}"
|
||||
try:
|
||||
_stop_use = ToolUse(
|
||||
id="gcu-cleanup",
|
||||
name="browser_stop",
|
||||
input={"profile": _subagent_profile},
|
||||
)
|
||||
_stop_result = tool_executor(_stop_use)
|
||||
if asyncio.iscoroutine(_stop_result) or asyncio.isfuture(_stop_result):
|
||||
await _stop_result
|
||||
except Exception as _gcu_exc:
|
||||
logger.warning(
|
||||
"GCU browser_stop failed for profile %r: %s",
|
||||
_subagent_profile,
|
||||
_gcu_exc,
|
||||
)
|
||||
@@ -0,0 +1,369 @@
|
||||
"""Synthetic tool builders for the event loop.
|
||||
|
||||
Factory functions that create ``Tool`` definitions for framework-level
|
||||
synthetic tools (set_output, ask_user, escalate, delegate, report_to_parent).
|
||||
Also includes the ``handle_set_output`` validation logic.
|
||||
|
||||
All functions are pure — they receive explicit parameters and return
|
||||
``Tool`` or ``ToolResult`` objects with no side effects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import Tool, ToolResult
|
||||
|
||||
|
||||
def build_ask_user_tool() -> Tool:
|
||||
"""Build the synthetic ask_user tool for explicit user-input requests.
|
||||
|
||||
Client-facing nodes call ask_user() when they need to pause and wait
|
||||
for user input. Text-only turns WITHOUT ask_user flow through without
|
||||
blocking, allowing progress updates and summaries to stream freely.
|
||||
"""
|
||||
return Tool(
|
||||
name="ask_user",
|
||||
description=(
|
||||
"You MUST call this tool whenever you need the user's response. "
|
||||
"Always call it after greeting the user, asking a question, or "
|
||||
"requesting approval. Do NOT call it for status updates or "
|
||||
"summaries that don't require a response. "
|
||||
"Always include 2-3 predefined options. The UI automatically "
|
||||
"appends an 'Other' free-text input after your options, so NEVER "
|
||||
"include catch-all options like 'Custom idea', 'Something else', "
|
||||
"'Other', or 'None of the above' — the UI handles that. "
|
||||
"When the question primarily needs a typed answer but you must "
|
||||
"include options, make one option signal that typing is expected "
|
||||
"(e.g. 'I\\'ll type my response'). This helps users discover the "
|
||||
"free-text input. "
|
||||
"The ONLY exception: omit options when the question demands a "
|
||||
"free-form answer the user must type out (e.g. 'Describe your "
|
||||
"agent idea', 'Paste the error message'). "
|
||||
'{"question": "What would you like to do?", "options": '
|
||||
'["Build a new agent", "Modify existing agent", "Run tests"]} '
|
||||
"Free-form example: "
|
||||
'{"question": "Describe the agent you want to build."}'
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question or prompt shown to the user.",
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"2-3 specific predefined choices. Include in most cases. "
|
||||
'Example: ["Option A", "Option B", "Option C"]. '
|
||||
"The UI always appends an 'Other' free-text input, so "
|
||||
"do NOT include catch-alls like 'Custom idea' or 'Other'. "
|
||||
"Omit ONLY when the user must type a free-form answer."
|
||||
),
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_ask_user_multiple_tool() -> Tool:
|
||||
"""Build the synthetic ask_user_multiple tool for batched questions.
|
||||
|
||||
Queen-only tool that presents multiple questions at once so the user
|
||||
can answer them all in a single interaction rather than one at a time.
|
||||
"""
|
||||
return Tool(
|
||||
name="ask_user_multiple",
|
||||
description=(
|
||||
"Ask the user multiple questions at once. Use this instead of "
|
||||
"ask_user when you have 2 or more questions to ask in the same "
|
||||
"turn — it lets the user answer everything in one go rather than "
|
||||
"going back and forth. Each question can have its own predefined "
|
||||
"options (2-3 choices) or be free-form. The UI renders all "
|
||||
"questions together with a single Submit button. "
|
||||
"ALWAYS prefer this over ask_user when you have multiple things "
|
||||
"to clarify. "
|
||||
"IMPORTANT: Do NOT repeat the questions in your text response — "
|
||||
"the widget renders them. Keep your text to a brief intro only. "
|
||||
'{"questions": ['
|
||||
' {"id": "scope", "prompt": "What scope?", "options": ["Full", "Partial"]},'
|
||||
' {"id": "format", "prompt": "Output format?", "options": ["PDF", "CSV", "JSON"]},'
|
||||
' {"id": "details", "prompt": "Any special requirements?"}'
|
||||
"]}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"questions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Short identifier for this question (used in the response)."
|
||||
),
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The question text shown to the user.",
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"2-3 predefined choices. The UI appends an "
|
||||
"'Other' free-text input automatically. "
|
||||
"Omit only when the user must type a free-form answer."
|
||||
),
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
},
|
||||
},
|
||||
"required": ["id", "prompt"],
|
||||
},
|
||||
"minItems": 2,
|
||||
"maxItems": 8,
|
||||
"description": "List of questions to present to the user.",
|
||||
},
|
||||
},
|
||||
"required": ["questions"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_set_output_tool(output_keys: list[str] | None) -> Tool | None:
|
||||
"""Build the synthetic set_output tool for explicit output declaration."""
|
||||
if not output_keys:
|
||||
return None
|
||||
return Tool(
|
||||
name="set_output",
|
||||
description=(
|
||||
"Set an output value for this node. Call once per output key. "
|
||||
"Use this for brief notes, counts, status, and file references — "
|
||||
"NOT for large data payloads. When a tool result was saved to a "
|
||||
"data file, pass the filename as the value "
|
||||
"(e.g. 'google_sheets_get_values_1.txt') so the next phase can "
|
||||
"load the full data. Values exceeding ~2000 characters are "
|
||||
"auto-saved to data files. "
|
||||
f"Valid keys: {output_keys}"
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": f"Output key. Must be one of: {output_keys}",
|
||||
"enum": output_keys,
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The output value — a brief note, count, status, "
|
||||
"or data filename reference."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_escalate_tool() -> Tool:
|
||||
"""Build the synthetic escalate tool for worker -> queen handoff."""
|
||||
return Tool(
|
||||
name="escalate",
|
||||
description=(
|
||||
"Escalate to the queen when requesting user input, "
|
||||
"blocked by errors, missing "
|
||||
"credentials, or ambiguous constraints that require supervisor "
|
||||
"guidance. Include a concise reason and optional context. "
|
||||
"The node will pause until the queen injects guidance."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Short reason for escalation (e.g. 'Tool repeatedly failing')."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Optional diagnostic details for the queen.",
|
||||
},
|
||||
},
|
||||
"required": ["reason"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_delegate_tool(sub_agents: list[str], node_registry: dict[str, Any]) -> Tool | None:
|
||||
"""Build the synthetic delegate_to_sub_agent tool for subagent invocation.
|
||||
|
||||
Args:
|
||||
sub_agents: List of node IDs that can be invoked as subagents.
|
||||
node_registry: Map of node_id -> NodeSpec for looking up subagent descriptions.
|
||||
|
||||
Returns:
|
||||
Tool definition if sub_agents is non-empty, None otherwise.
|
||||
"""
|
||||
if not sub_agents:
|
||||
return None
|
||||
|
||||
agent_descriptions = []
|
||||
for agent_id in sub_agents:
|
||||
spec = node_registry.get(agent_id)
|
||||
if spec:
|
||||
desc = getattr(spec, "description", "(no description)")
|
||||
agent_descriptions.append(f"- {agent_id}: {desc}")
|
||||
else:
|
||||
agent_descriptions.append(f"- {agent_id}: (not found in registry)")
|
||||
|
||||
return Tool(
|
||||
name="delegate_to_sub_agent",
|
||||
description=(
|
||||
"Delegate a task to a specialized sub-agent. The sub-agent runs "
|
||||
"autonomously with read-only access to current memory and returns "
|
||||
"its result. Use this to parallelize work or leverage specialized capabilities.\n\n"
|
||||
"Available sub-agents:\n" + "\n".join(agent_descriptions)
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": f"The sub-agent to invoke. Must be one of: {sub_agents}",
|
||||
"enum": sub_agents,
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The task description for the sub-agent to execute. "
|
||||
"Be specific about what you want the sub-agent to do and "
|
||||
"what information to return."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "task"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_report_to_parent_tool() -> Tool:
|
||||
"""Build the synthetic report_to_parent tool for sub-agent progress reports.
|
||||
|
||||
Sub-agents call this to send one-way progress updates, partial findings,
|
||||
or status reports to the parent node (and external observers via event bus)
|
||||
without blocking execution.
|
||||
|
||||
When ``wait_for_response`` is True, the sub-agent blocks until the parent
|
||||
relays the user's response — used for escalation (e.g. login pages, CAPTCHAs).
|
||||
|
||||
When ``mark_complete`` is True, the sub-agent terminates immediately after
|
||||
sending the report — no need to call set_output for each output key.
|
||||
"""
|
||||
return Tool(
|
||||
name="report_to_parent",
|
||||
description=(
|
||||
"Send a report to the parent agent. By default this is fire-and-forget: "
|
||||
"the parent receives the report but does not respond. "
|
||||
"Set wait_for_response=true to BLOCK until the user replies — use this "
|
||||
"when you need human intervention (e.g. login pages, CAPTCHAs, "
|
||||
"authentication walls). The user's response is returned as the tool result. "
|
||||
"Set mark_complete=true to finish your task and terminate immediately "
|
||||
"after sending the report — use this when your findings are in the "
|
||||
"message/data fields and you don't need to call set_output."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "A human-readable status or progress message.",
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "Optional structured data to include with the report.",
|
||||
},
|
||||
"wait_for_response": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, block execution until the user responds. "
|
||||
"Use for escalation scenarios requiring human intervention."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
"mark_complete": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, terminate the sub-agent immediately after sending "
|
||||
"this report. The report message and data are delivered to the "
|
||||
"parent as the final result. No set_output calls are needed."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def handle_set_output(
|
||||
tool_input: dict[str, Any],
|
||||
output_keys: list[str] | None,
|
||||
) -> ToolResult:
|
||||
"""Handle set_output tool call. Returns ToolResult (sync)."""
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
key = tool_input.get("key", "")
|
||||
value = tool_input.get("value", "")
|
||||
valid_keys = output_keys or []
|
||||
|
||||
# Recover from truncated JSON (max_tokens hit mid-argument).
|
||||
# The _raw key is set by litellm when json.loads fails.
|
||||
if not key and "_raw" in tool_input:
|
||||
raw = tool_input["_raw"]
|
||||
key_match = re.search(r'"key"\s*:\s*"(\w+)"', raw)
|
||||
if key_match:
|
||||
key = key_match.group(1)
|
||||
val_match = re.search(r'"value"\s*:\s*"', raw)
|
||||
if val_match:
|
||||
start = val_match.end()
|
||||
value = raw[start:].rstrip()
|
||||
for suffix in ('"}\n', '"}', '"'):
|
||||
if value.endswith(suffix):
|
||||
value = value[: -len(suffix)]
|
||||
break
|
||||
if key:
|
||||
logger.warning(
|
||||
"Recovered set_output args from truncated JSON: key=%s, value_len=%d",
|
||||
key,
|
||||
len(value),
|
||||
)
|
||||
# Re-inject so the caller sees proper key/value
|
||||
tool_input["key"] = key
|
||||
tool_input["value"] = value
|
||||
|
||||
if key not in valid_keys:
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Invalid output key '{key}'. Valid keys: {valid_keys}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id="",
|
||||
content=f"Output '{key}' set successfully.",
|
||||
is_error=False,
|
||||
)
|
||||
@@ -0,0 +1,542 @@
|
||||
"""Tool result handling: truncation, spillover, JSON preview, and execution.
|
||||
|
||||
Manages tool result size limits, file spillover for large results, and
|
||||
smart JSON previews. Also includes transient error classification and
|
||||
the context-window-exceeded error detector.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import ToolResult, ToolUse
|
||||
from framework.llm.stream_events import ToolCallEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pattern for detecting context-window-exceeded errors across LLM providers.
|
||||
_CONTEXT_TOO_LARGE_RE = re.compile(
|
||||
r"context.{0,20}(length|window|limit|size)|"
|
||||
r"too.{0,10}(long|large|many.{0,10}tokens)|"
|
||||
r"(exceed|exceeds|exceeded).{0,30}(limit|window|context|tokens)|"
|
||||
r"maximum.{0,20}token|prompt.{0,20}too.{0,10}long",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_context_too_large_error(exc: BaseException) -> bool:
|
||||
"""Detect whether an exception indicates the LLM input was too large."""
|
||||
cls = type(exc).__name__
|
||||
if "ContextWindow" in cls:
|
||||
return True
|
||||
return bool(_CONTEXT_TOO_LARGE_RE.search(str(exc)))
|
||||
|
||||
|
||||
def is_transient_error(exc: BaseException) -> bool:
|
||||
"""Classify whether an exception is transient (retryable) vs permanent.
|
||||
|
||||
Transient: network errors, rate limits, server errors, timeouts.
|
||||
Permanent: auth errors, bad requests, context window exceeded.
|
||||
"""
|
||||
try:
|
||||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
|
||||
transient_types: tuple[type[BaseException], ...] = (
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
InternalServerError,
|
||||
BadGatewayError,
|
||||
ServiceUnavailableError,
|
||||
TimeoutError,
|
||||
ConnectionError,
|
||||
OSError,
|
||||
)
|
||||
except ImportError:
|
||||
transient_types = (TimeoutError, ConnectionError, OSError)
|
||||
|
||||
if isinstance(exc, transient_types):
|
||||
return True
|
||||
|
||||
# RuntimeError from StreamErrorEvent with "Stream error:" prefix
|
||||
if isinstance(exc, RuntimeError):
|
||||
error_str = str(exc).lower()
|
||||
transient_keywords = [
|
||||
"rate limit",
|
||||
"429",
|
||||
"timeout",
|
||||
"connection",
|
||||
"internal server",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"service unavailable",
|
||||
"bad gateway",
|
||||
"overloaded",
|
||||
"failed to parse tool call",
|
||||
]
|
||||
return any(kw in error_str for kw in transient_keywords)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def extract_json_metadata(parsed: Any, *, _depth: int = 0, _max_depth: int = 3) -> str:
|
||||
"""Return a concise structural summary of parsed JSON.
|
||||
|
||||
Reports key names, value types, and — crucially — array lengths so
|
||||
the LLM knows how much data exists beyond the preview.
|
||||
|
||||
Returns an empty string for simple scalars.
|
||||
"""
|
||||
if _depth >= _max_depth:
|
||||
if isinstance(parsed, dict):
|
||||
return f"dict with {len(parsed)} keys"
|
||||
if isinstance(parsed, list):
|
||||
return f"list of {len(parsed)} items"
|
||||
return type(parsed).__name__
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
if not parsed:
|
||||
return "empty dict"
|
||||
lines: list[str] = []
|
||||
indent = " " * (_depth + 1)
|
||||
for key, value in list(parsed.items())[:20]:
|
||||
if isinstance(value, list):
|
||||
line = f'{indent}"{key}": list of {len(value)} items'
|
||||
if value:
|
||||
first = value[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
line += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
line += f" (each item: list of {len(first)} elements)"
|
||||
lines.append(line)
|
||||
elif isinstance(value, dict):
|
||||
child = extract_json_metadata(value, _depth=_depth + 1, _max_depth=_max_depth)
|
||||
lines.append(f'{indent}"{key}": {child}')
|
||||
else:
|
||||
lines.append(f'{indent}"{key}": {type(value).__name__}')
|
||||
if len(parsed) > 20:
|
||||
lines.append(f"{indent}... and {len(parsed) - 20} more keys")
|
||||
return "\n".join(lines)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
if not parsed:
|
||||
return "empty list"
|
||||
desc = f"list of {len(parsed)} items"
|
||||
first = parsed[0]
|
||||
if isinstance(first, dict):
|
||||
sample_keys = list(first.keys())[:10]
|
||||
desc += f" (each item: dict with keys {sample_keys})"
|
||||
elif isinstance(first, list):
|
||||
desc += f" (each item: list of {len(first)} elements)"
|
||||
return desc
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def build_json_preview(parsed: Any, *, max_chars: int = 5000) -> str | None:
|
||||
"""Build a smart preview of parsed JSON, truncating large arrays.
|
||||
|
||||
Shows first 3 + last 1 items of large arrays with explicit count
|
||||
markers so the LLM cannot mistake the preview for the full dataset.
|
||||
|
||||
Returns ``None`` if no truncation was needed (no large arrays).
|
||||
"""
|
||||
_LARGE_ARRAY_THRESHOLD = 10
|
||||
|
||||
def _truncate_arrays(obj: Any) -> tuple[Any, bool]:
|
||||
"""Return (truncated_copy, was_truncated)."""
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
head = obj[:3]
|
||||
tail = obj[-1:]
|
||||
marker = f"... ({n - 4} more items omitted, {n} total) ..."
|
||||
return head + [marker] + tail, True
|
||||
if isinstance(obj, dict):
|
||||
changed = False
|
||||
out: dict[str, Any] = {}
|
||||
for k, v in obj.items():
|
||||
new_v, did = _truncate_arrays(v)
|
||||
out[k] = new_v
|
||||
changed = changed or did
|
||||
return (out, True) if changed else (obj, False)
|
||||
return obj, False
|
||||
|
||||
preview_obj, was_truncated = _truncate_arrays(parsed)
|
||||
if not was_truncated:
|
||||
return None # No large arrays — caller should use raw slicing
|
||||
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if len(result) > max_chars:
|
||||
# Even 3+1 items too big — try just 1 item
|
||||
def _minimal_arrays(obj: Any) -> Any:
|
||||
if isinstance(obj, list) and len(obj) > _LARGE_ARRAY_THRESHOLD:
|
||||
n = len(obj)
|
||||
return obj[:1] + [f"... ({n - 1} more items omitted, {n} total) ..."]
|
||||
if isinstance(obj, dict):
|
||||
return {k: _minimal_arrays(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
preview_obj = _minimal_arrays(parsed)
|
||||
try:
|
||||
result = json.dumps(preview_obj, indent=2, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if len(result) > max_chars:
|
||||
result = result[:max_chars] + "…"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def truncate_tool_result(
|
||||
result: ToolResult,
|
||||
tool_name: str,
|
||||
*,
|
||||
max_tool_result_chars: int,
|
||||
spillover_dir: str | None,
|
||||
next_spill_filename_fn: Any, # Callable[[str], str]
|
||||
) -> ToolResult:
|
||||
"""Persist tool result to file and optionally truncate for context.
|
||||
|
||||
When *spillover_dir* is configured, EVERY non-error tool result is
|
||||
saved to a file (short filename like ``web_search_1.txt``). A
|
||||
``[Saved to '...']`` annotation is appended so the reference
|
||||
survives pruning and compaction.
|
||||
|
||||
- Small results (≤ limit): full content kept + file annotation
|
||||
- Large results (> limit): preview + file reference
|
||||
- Errors: pass through unchanged
|
||||
- load_data results: truncate with pagination hint (no re-spill)
|
||||
"""
|
||||
limit = max_tool_result_chars
|
||||
|
||||
# Errors always pass through unchanged
|
||||
if result.is_error:
|
||||
return result
|
||||
|
||||
# load_data reads FROM spilled files — never re-spill (circular).
|
||||
# Just truncate with a pagination hint if the result is too large.
|
||||
if tool_name == "load_data":
|
||||
if limit <= 0 or len(result.content) <= limit:
|
||||
return result # Small load_data result — pass through as-is
|
||||
# Large load_data result — truncate with smart preview
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_ld = json.loads(result.content)
|
||||
metadata_str = extract_json_metadata(parsed_ld)
|
||||
smart_preview = build_json_preview(parsed_ld, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[{tool_name} result: {len(result.content):,} chars — "
|
||||
f"too large for context. Use offset_bytes/limit_bytes "
|
||||
f"parameters to read smaller chunks.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. Do NOT draw conclusions or counts from it."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"%s result truncated: %d → %d chars (use offset/limit to paginate)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
len(truncated),
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
spill_dir = spillover_dir
|
||||
if spill_dir:
|
||||
spill_path = Path(spill_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
filename = next_spill_filename_fn(tool_name)
|
||||
|
||||
# Pretty-print JSON content so load_data's line-based
|
||||
# pagination works correctly.
|
||||
write_content = result.content
|
||||
parsed_json: Any = None # track for metadata extraction
|
||||
try:
|
||||
parsed_json = json.loads(result.content)
|
||||
write_content = json.dumps(parsed_json, indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass # Not JSON — write as-is
|
||||
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
# Large result: build a small, metadata-rich preview so the
|
||||
# LLM cannot mistake it for the complete dataset.
|
||||
PREVIEW_CAP = 5000
|
||||
|
||||
# Extract structural metadata (array lengths, key names)
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
if parsed_json is not None:
|
||||
metadata_str = extract_json_metadata(parsed_json)
|
||||
smart_preview = build_json_preview(parsed_json, max_chars=PREVIEW_CAP)
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
# Assemble header with structural info + warning
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"too large for context, saved to '{filename}'.]\n"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
f"\n\nWARNING: The preview below is INCOMPLETE. "
|
||||
f"Do NOT draw conclusions or counts from it. "
|
||||
f"Use load_data(filename='{filename}') to read the "
|
||||
f"full data before analysis."
|
||||
)
|
||||
|
||||
content = f"{header}\n\nPreview (small sample only):\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result spilled to file: %s (%d chars → %s)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
filename,
|
||||
)
|
||||
else:
|
||||
# Small result: keep full content + annotation
|
||||
content = f"{result.content}\n\n[Saved to '{filename}']"
|
||||
logger.info(
|
||||
"Tool result saved to file: %s (%d chars → %s)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
filename,
|
||||
)
|
||||
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=content,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
# No spillover_dir — truncate in-place if needed
|
||||
if limit > 0 and len(result.content) > limit:
|
||||
PREVIEW_CAP = min(5000, max(limit - 500, limit // 2))
|
||||
|
||||
metadata_str = ""
|
||||
smart_preview: str | None = None
|
||||
try:
|
||||
parsed_inline = json.loads(result.content)
|
||||
metadata_str = extract_json_metadata(parsed_inline)
|
||||
smart_preview = build_json_preview(parsed_inline, max_chars=PREVIEW_CAP)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if smart_preview is not None:
|
||||
preview_block = smart_preview
|
||||
else:
|
||||
preview_block = result.content[:PREVIEW_CAP] + "…"
|
||||
|
||||
header = (
|
||||
f"[Result from {tool_name}: {len(result.content):,} chars — "
|
||||
f"truncated to fit context budget.]"
|
||||
)
|
||||
if metadata_str:
|
||||
header += f"\n\nData structure:\n{metadata_str}"
|
||||
header += (
|
||||
"\n\nWARNING: This is an INCOMPLETE preview. "
|
||||
"Do NOT draw conclusions or counts from the preview alone."
|
||||
)
|
||||
|
||||
truncated = f"{header}\n\n{preview_block}"
|
||||
logger.info(
|
||||
"Tool result truncated in-place: %s (%d → %d chars)",
|
||||
tool_name,
|
||||
len(result.content),
|
||||
len(truncated),
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=result.tool_use_id,
|
||||
content=truncated,
|
||||
is_error=False,
|
||||
image_content=result.image_content,
|
||||
is_skill_content=result.is_skill_content,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
tool_executor: Any, # Callable[[ToolUse], ToolResult | Awaitable[ToolResult]] | None
|
||||
tc: ToolCallEvent,
|
||||
timeout: float,
|
||||
skill_dirs: list[str] | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a tool call, handling both sync and async executors.
|
||||
|
||||
Applies ``tool_call_timeout_seconds`` to prevent hung MCP servers
|
||||
from blocking the event loop indefinitely. The initial executor
|
||||
call is offloaded to a thread pool so that sync executors don't
|
||||
freeze the event loop.
|
||||
"""
|
||||
if tool_executor is None:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"No tool executor configured for '{tc.tool_name}'",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
skill_dirs = skill_dirs or []
|
||||
skill_read_tools = {"view_file", "load_data", "read_file"}
|
||||
if tc.tool_name in skill_read_tools and skill_dirs:
|
||||
raw_path = tc.tool_input.get("path", "")
|
||||
if raw_path:
|
||||
resolved = Path(raw_path).resolve(strict=False)
|
||||
resolved_roots = [Path(skill_dir).resolve(strict=False) for skill_dir in skill_dirs]
|
||||
if any(resolved.is_relative_to(root) for root in resolved_roots):
|
||||
try:
|
||||
content = resolved.read_text(encoding="utf-8")
|
||||
except Exception as exc:
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=f"Could not read skill resource '{raw_path}': {exc}",
|
||||
is_error=True,
|
||||
)
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=content,
|
||||
is_skill_content=resolved.name == "SKILL.md",
|
||||
)
|
||||
|
||||
tool_use = ToolUse(id=tc.tool_use_id, name=tc.tool_name, input=tc.tool_input)
|
||||
|
||||
async def _run() -> ToolResult:
|
||||
# Offload the executor call to a thread. Sync MCP executors
|
||||
# block on future.result() — running in a thread keeps the
|
||||
# event loop free so asyncio.wait_for can fire the timeout.
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, tool_executor, tool_use)
|
||||
# Async executors return a coroutine — await it on the loop
|
||||
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
try:
|
||||
if timeout > 0:
|
||||
result = await asyncio.wait_for(_run(), timeout=timeout)
|
||||
else:
|
||||
result = await _run()
|
||||
except TimeoutError:
|
||||
logger.warning("Tool '%s' timed out after %.0fs", tc.tool_name, timeout)
|
||||
return ToolResult(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=(
|
||||
f"Tool '{tc.tool_name}' timed out after {timeout:.0f}s. "
|
||||
"The operation took too long and was cancelled. "
|
||||
"Try a simpler request or a different approach."
|
||||
),
|
||||
is_error=True,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def record_learning(key: str, value: Any, spillover_dir: str | None) -> None:
|
||||
"""Append a set_output value to adapt.md as a learning entry.
|
||||
|
||||
Called at set_output time — the moment knowledge is produced — so that
|
||||
adapt.md accumulates the agent's outputs across the session. Since
|
||||
adapt.md is injected into the system prompt, these persist through
|
||||
any compaction.
|
||||
"""
|
||||
if not spillover_dir:
|
||||
return
|
||||
try:
|
||||
adapt_path = Path(spillover_dir) / "adapt.md"
|
||||
adapt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = adapt_path.read_text(encoding="utf-8") if adapt_path.exists() else ""
|
||||
|
||||
if "## Outputs" not in content:
|
||||
content += "\n\n## Outputs\n"
|
||||
|
||||
# Truncate long values for memory (full value is in shared memory)
|
||||
v_str = str(value)
|
||||
if len(v_str) > 500:
|
||||
v_str = v_str[:500] + "…"
|
||||
|
||||
entry = f"- {key}: {v_str}\n"
|
||||
|
||||
# Replace existing entry for same key (update, not duplicate)
|
||||
lines = content.splitlines(keepends=True)
|
||||
replaced = False
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith(f"- {key}:"):
|
||||
lines[i] = entry
|
||||
replaced = True
|
||||
break
|
||||
if replaced:
|
||||
content = "".join(lines)
|
||||
else:
|
||||
content += entry
|
||||
|
||||
adapt_path.write_text(content, encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record learning for key=%s: %s", key, e)
|
||||
|
||||
|
||||
def next_spill_filename(tool_name: str, counter: int) -> str:
|
||||
"""Return a short, monotonic filename for a tool result spill."""
|
||||
# Shorten common tool name prefixes to save tokens
|
||||
short = tool_name.removeprefix("tool_").removeprefix("mcp_")
|
||||
return f"{short}_{counter}.txt"
|
||||
|
||||
|
||||
def restore_spill_counter(spillover_dir: str | None) -> int:
|
||||
"""Scan spillover_dir for existing spill files and return the max counter.
|
||||
|
||||
Returns the highest spill number found (or 0 if none).
|
||||
"""
|
||||
if not spillover_dir:
|
||||
return 0
|
||||
spill_path = Path(spillover_dir)
|
||||
if not spill_path.is_dir():
|
||||
return 0
|
||||
max_n = 0
|
||||
for f in spill_path.iterdir():
|
||||
if not f.is_file():
|
||||
continue
|
||||
m = re.search(r"_(\d+)\.txt$", f.name)
|
||||
if m:
|
||||
max_n = max(max_n, int(m.group(1)))
|
||||
return max_n
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Shared types and state containers for the event loop package."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from framework.graph.conversation import ConversationStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TriggerEvent:
|
||||
"""A framework-level trigger signal (timer tick or webhook hit)."""
|
||||
|
||||
trigger_type: str
|
||||
source_id: str
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeVerdict:
|
||||
"""Result of judge evaluation for the event loop."""
|
||||
|
||||
action: Literal["ACCEPT", "RETRY", "ESCALATE"]
|
||||
# None = no evaluation happened (skip_judge, tool-continue); not logged.
|
||||
# "" = evaluated but no feedback; logged with default text.
|
||||
# "..." = evaluated with feedback; logged as-is.
|
||||
feedback: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class JudgeProtocol(Protocol):
|
||||
"""Protocol for event-loop judges."""
|
||||
|
||||
async def evaluate(self, context: dict[str, Any]) -> JudgeVerdict: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopConfig:
|
||||
"""Configuration for the event loop."""
|
||||
|
||||
max_iterations: int = 50
|
||||
max_tool_calls_per_turn: int = 30
|
||||
judge_every_n_turns: int = 1
|
||||
stall_detection_threshold: int = 3
|
||||
stall_similarity_threshold: float = 0.85
|
||||
max_context_tokens: int = 32_000
|
||||
store_prefix: str = ""
|
||||
|
||||
# Overflow margin for max_tool_calls_per_turn. Tool calls are only
|
||||
# discarded when the count exceeds max_tool_calls_per_turn * (1 + margin).
|
||||
tool_call_overflow_margin: float = 0.5
|
||||
|
||||
# Tool result context management.
|
||||
max_tool_result_chars: int = 30_000
|
||||
spillover_dir: str | None = None
|
||||
|
||||
# set_output value spilling.
|
||||
max_output_value_chars: int = 2_000
|
||||
|
||||
# Stream retry.
|
||||
max_stream_retries: int = 3
|
||||
stream_retry_backoff_base: float = 2.0
|
||||
stream_retry_max_delay: float = 60.0
|
||||
|
||||
# Tool doom loop detection.
|
||||
tool_doom_loop_threshold: int = 3
|
||||
|
||||
# Client-facing auto-block grace period.
|
||||
cf_grace_turns: int = 1
|
||||
tool_doom_loop_enabled: bool = True
|
||||
|
||||
# Per-tool-call timeout.
|
||||
tool_call_timeout_seconds: float = 60.0
|
||||
|
||||
# Subagent delegation timeout.
|
||||
subagent_timeout_seconds: float = 600.0
|
||||
|
||||
# Lifecycle hooks.
|
||||
hooks: dict[str, list] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.hooks is None:
|
||||
object.__setattr__(self, "hooks", {})
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookContext:
|
||||
"""Context passed to every lifecycle hook."""
|
||||
|
||||
event: str
|
||||
trigger: str | None
|
||||
system_prompt: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookResult:
|
||||
"""What a hook may return to modify node state."""
|
||||
|
||||
system_prompt: str | None = None
|
||||
inject: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputAccumulator:
|
||||
"""Accumulates output key-value pairs with optional write-through persistence."""
|
||||
|
||||
values: dict[str, Any] = field(default_factory=dict)
|
||||
store: ConversationStore | None = None
|
||||
spillover_dir: str | None = None
|
||||
max_value_chars: int = 0
|
||||
|
||||
async def set(self, key: str, value: Any) -> None:
|
||||
"""Set a key-value pair, auto-spilling large values to files."""
|
||||
value = self._auto_spill(key, value)
|
||||
self.values[key] = value
|
||||
if self.store:
|
||||
cursor = await self.store.read_cursor() or {}
|
||||
outputs = cursor.get("outputs", {})
|
||||
outputs[key] = value
|
||||
cursor["outputs"] = outputs
|
||||
await self.store.write_cursor(cursor)
|
||||
|
||||
def _auto_spill(self, key: str, value: Any) -> Any:
|
||||
"""Save large values to a file and return a reference string."""
|
||||
if self.max_value_chars <= 0 or not self.spillover_dir:
|
||||
return value
|
||||
|
||||
val_str = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
if len(val_str) <= self.max_value_chars:
|
||||
return value
|
||||
|
||||
spill_path = Path(self.spillover_dir)
|
||||
spill_path.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".json" if isinstance(value, (dict, list)) else ".txt"
|
||||
filename = f"output_{key}{ext}"
|
||||
write_content = (
|
||||
json.dumps(value, indent=2, ensure_ascii=False)
|
||||
if isinstance(value, (dict, list))
|
||||
else str(value)
|
||||
)
|
||||
(spill_path / filename).write_text(write_content, encoding="utf-8")
|
||||
file_size = (spill_path / filename).stat().st_size
|
||||
logger.info(
|
||||
"set_output value auto-spilled: key=%s, %d chars -> %s (%d bytes)",
|
||||
key,
|
||||
len(val_str),
|
||||
filename,
|
||||
file_size,
|
||||
)
|
||||
return (
|
||||
f"[Saved to '{filename}' ({file_size:,} bytes). "
|
||||
f"Use load_data(filename='{filename}') "
|
||||
f"to access full data.]"
|
||||
)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
return self.values.get(key)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return dict(self.values)
|
||||
|
||||
def has_all_keys(self, required: list[str]) -> bool:
|
||||
return all(key in self.values and self.values[key] is not None for key in required)
|
||||
|
||||
@classmethod
|
||||
async def restore(cls, store: ConversationStore) -> OutputAccumulator:
|
||||
cursor = await store.read_cursor()
|
||||
values = {}
|
||||
if cursor and "outputs" in cursor:
|
||||
values = cursor["outputs"]
|
||||
return cls(values=values, store=store)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HookContext",
|
||||
"HookResult",
|
||||
"JudgeProtocol",
|
||||
"JudgeVerdict",
|
||||
"LoopConfig",
|
||||
"OutputAccumulator",
|
||||
"TriggerEvent",
|
||||
]
|
||||
+366
-2428
File diff suppressed because it is too large
Load Diff
@@ -155,6 +155,8 @@ class GraphExecutor:
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the executor.
|
||||
@@ -183,6 +185,8 @@ class GraphExecutor:
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
|
||||
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
|
||||
"""
|
||||
self.runtime = runtime
|
||||
self.llm = llm
|
||||
@@ -207,6 +211,8 @@ class GraphExecutor:
|
||||
self.skills_catalog_prompt = skills_catalog_prompt
|
||||
self.protocols_prompt = protocols_prompt
|
||||
self.skill_dirs: list[str] = skill_dirs or []
|
||||
self.context_warn_ratio: float | None = context_warn_ratio
|
||||
self.batch_init_nudge: str | None = batch_init_nudge
|
||||
|
||||
if protocols_prompt:
|
||||
self.logger.info(
|
||||
@@ -1906,6 +1912,8 @@ class GraphExecutor:
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
default_skill_warn_ratio=self.context_warn_ratio,
|
||||
default_skill_batch_nudge=self.batch_init_nudge,
|
||||
)
|
||||
|
||||
VALID_NODE_TYPES = {
|
||||
|
||||
@@ -569,6 +569,10 @@ class NodeContext:
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
protocols_prompt: str = "" # Default skill operational protocols
|
||||
skill_dirs: list[str] = field(default_factory=list) # Skill base dirs for resource access
|
||||
# DS-12: batch auto-detection nudge appended to system prompt when input looks like a batch
|
||||
default_skill_batch_nudge: str | None = None
|
||||
# DS-13: token usage ratio at which to inject a context preservation warning
|
||||
default_skill_warn_ratio: float | None = None
|
||||
|
||||
# Per-iteration metadata provider — when set, EventLoopNode merges
|
||||
# the returned dict into node_loop_iteration event data. Used by
|
||||
|
||||
@@ -159,6 +159,26 @@ if litellm is not None:
|
||||
# (e.g. stream_options for Anthropic) instead of forwarding them verbatim.
|
||||
litellm.drop_params = True
|
||||
|
||||
|
||||
def _is_ollama_model(model: str) -> bool:
|
||||
"""Return True for any Ollama model string (ollama/ or ollama_chat/ prefix)."""
|
||||
return model.startswith("ollama/") or model.startswith("ollama_chat/")
|
||||
|
||||
|
||||
def _ensure_ollama_chat_prefix(model: str) -> str:
|
||||
"""Normalise Ollama model strings to use the ollama_chat/ prefix.
|
||||
|
||||
LiteLLM requires the ``ollama_chat/`` prefix (not ``ollama/``) to enable
|
||||
native function-calling support. With ``ollama/``, LiteLLM falls back to
|
||||
JSON-mode tool calls, which the framework cannot parse as real tool calls.
|
||||
|
||||
See: https://docs.litellm.ai/docs/providers/ollama#example-usage---tool-calling
|
||||
"""
|
||||
if model.startswith("ollama/"):
|
||||
return "ollama_chat/" + model[len("ollama/") :]
|
||||
return model
|
||||
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
@@ -499,7 +519,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Translate kimi/ prefix to anthropic/ so litellm uses the Anthropic
|
||||
# Messages API handler and routes to that endpoint — no special headers needed.
|
||||
_original_model = model
|
||||
if model.lower().startswith("kimi/"):
|
||||
if _is_ollama_model(model):
|
||||
model = _ensure_ollama_chat_prefix(model)
|
||||
elif model.lower().startswith("kimi/"):
|
||||
model = "anthropic/" + model[len("kimi/") :]
|
||||
# Normalise api_base: litellm's Anthropic handler appends /v1/messages,
|
||||
# so the base must be https://api.kimi.com/coding (no /v1 suffix).
|
||||
@@ -722,6 +744,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
|
||||
# Add response_format for structured output
|
||||
# LiteLLM passes this through to the underlying provider
|
||||
@@ -919,6 +945,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
@@ -1620,6 +1650,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["api_base"] = self.api_base
|
||||
if tools:
|
||||
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
|
||||
if _is_ollama_model(self.model):
|
||||
# Ollama requires explicit tool_choice=auto for function calling
|
||||
# so future readers don't have to guess.
|
||||
kwargs.setdefault("tool_choice", "auto")
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
# The Codex ChatGPT backend (Responses API) rejects several params.
|
||||
|
||||
@@ -14,6 +14,8 @@ from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from framework.runner.mcp_errors import MCPToolNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -456,7 +458,10 @@ class MCPClient:
|
||||
self.connect()
|
||||
|
||||
if tool_name not in self._tools:
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
raise MCPToolNotFoundError(
|
||||
server=self.config.name,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
|
||||
if self.config.transport == "stdio":
|
||||
with self._stdio_call_lock:
|
||||
@@ -507,7 +512,10 @@ class MCPClient:
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
error_text = content_item.text
|
||||
raise RuntimeError(f"MCP tool '{tool_name}' failed: {error_text}")
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Tool '{tool_name}' failed: {error_text}"
|
||||
)
|
||||
|
||||
# Extract content — preserve image blocks alongside text
|
||||
if result.content:
|
||||
@@ -558,11 +566,17 @@ class MCPClient:
|
||||
data = response.json()
|
||||
|
||||
if "error" in data:
|
||||
raise RuntimeError(f"Tool execution error: {data['error']}")
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Tool '{tool_name}' failed: {data['error']}"
|
||||
)
|
||||
|
||||
return data.get("result", {}).get("content", [])
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e
|
||||
raise RuntimeError(
|
||||
f"[Server: {self.config.name}] [Transport: {self.config.transport}] "
|
||||
f"Failed to call tool via HTTP: Tool '{tool_name}' failed: {e}"
|
||||
) from e
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the configured MCP server."""
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Structured error codes and exceptions for MCP server operations."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MCPErrorCode(Enum):
|
||||
"""Standardized error codes for MCP operations."""
|
||||
|
||||
MCP_INSTALL_FAILED = "MCP_INSTALL_FAILED"
|
||||
MCP_AUTH_MISSING = "MCP_AUTH_MISSING"
|
||||
MCP_CONNECT_TIMEOUT = "MCP_CONNECT_TIMEOUT"
|
||||
MCP_TOOL_NOT_FOUND = "MCP_TOOL_NOT_FOUND"
|
||||
MCP_PROTOCOL_MISMATCH = "MCP_PROTOCOL_MISMATCH"
|
||||
MCP_VERSION_CONFLICT = "MCP_VERSION_CONFLICT"
|
||||
MCP_HEALTH_FAILED = "MCP_HEALTH_FAILED"
|
||||
|
||||
|
||||
class MCPError(ValueError):
|
||||
"""Base exception for all structured MCP errors."""
|
||||
|
||||
def __init__(self, code: MCPErrorCode, what: str, why: str, fix: str):
|
||||
self.code = code
|
||||
self.what = what
|
||||
self.why = why
|
||||
self.fix = fix
|
||||
self.message = (
|
||||
f"[{self.code.value}]\nWhat failed: {self.what}\nWhy: {self.why}\nFix: {self.fix}"
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
def __init__(self, server: str, tool_name: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_TOOL_NOT_FOUND,
|
||||
what=f"Tool '{tool_name}' not found on server '{server}'",
|
||||
why=f"The server '{server}' does not expose a tool named '{tool_name}'.",
|
||||
fix=f"Run 'hive mcp inspect {server}' to view available tools.",
|
||||
)
|
||||
|
||||
|
||||
class MCPConnectTimeoutError(MCPError):
|
||||
def __init__(self, server: str, transport: str, timeout_sec: int):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_CONNECT_TIMEOUT,
|
||||
what=f"Connection timed out while starting server '{server}'",
|
||||
why=f"The {transport} transport did not respond within {timeout_sec} seconds.",
|
||||
fix=f"Check if the server is running. Run 'hive mcp doctor {server}' for diagnostics.",
|
||||
)
|
||||
|
||||
|
||||
class MCPAuthError(MCPError):
|
||||
def __init__(self, server: str, env_var: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_AUTH_MISSING,
|
||||
what=f"Authentication failed for server '{server}'",
|
||||
why=f"The required environment variable '{env_var}' is missing or empty.",
|
||||
fix=f"Run: hive mcp config {server} --set {env_var}=<your-token>",
|
||||
)
|
||||
|
||||
|
||||
class MCPInstallError(MCPError):
|
||||
def __init__(self, server: str, why: str, fix: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Could not install MCP server '{server}'",
|
||||
why=why,
|
||||
fix=fix,
|
||||
)
|
||||
|
||||
|
||||
class MCPProtocolMismatchError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_PROTOCOL_MISMATCH,
|
||||
what=f"Protocol mismatch with server '{server}'",
|
||||
why=detail,
|
||||
fix=f"Check the MCP SDK version required by '{server}' matches your installation.",
|
||||
)
|
||||
|
||||
|
||||
class MCPVersionConflictError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Version conflict with server '{server}'",
|
||||
why=detail,
|
||||
fix="Update or pin the MCP server package to a compatible version.",
|
||||
)
|
||||
|
||||
|
||||
class MCPHealthCheckError(MCPError):
|
||||
def __init__(self, server: str, detail: str):
|
||||
super().__init__(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Health check failed for server '{server}'",
|
||||
why=detail,
|
||||
fix=f"Run 'hive mcp doctor {server}' to diagnose the issue.",
|
||||
)
|
||||
@@ -16,6 +16,11 @@ import httpx
|
||||
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
from framework.runner.mcp_connection_manager import MCPConnectionManager
|
||||
from framework.runner.mcp_errors import (
|
||||
MCPError,
|
||||
MCPErrorCode,
|
||||
MCPInstallError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -141,7 +146,12 @@ class MCPRegistry:
|
||||
"""
|
||||
data = self._read_installed()
|
||||
if name in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' already exists. Use remove first.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Server '{name}' already exists",
|
||||
why="A server with this name is already registered locally.",
|
||||
fix=f"Run: hive mcp remove {name} — then add it again.",
|
||||
)
|
||||
|
||||
if manifest is not None:
|
||||
# Inline manifest provided directly
|
||||
@@ -153,7 +163,12 @@ class MCPRegistry:
|
||||
else:
|
||||
# Build manifest from individual params
|
||||
if not transport:
|
||||
raise ValueError("transport is required when manifest is not provided")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}'",
|
||||
why="transport is required when manifest is not provided.",
|
||||
fix="Pass --transport stdio|http|unix|sse when using hive mcp add.",
|
||||
)
|
||||
manifest = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
@@ -162,11 +177,21 @@ class MCPRegistry:
|
||||
match transport:
|
||||
case "http":
|
||||
if not url:
|
||||
raise ValueError("url is required for http transport")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with http transport",
|
||||
why="url is required for http transport.",
|
||||
fix="Pass --url https://your-server to hive mcp add.",
|
||||
)
|
||||
manifest["http"] = {"url": url, "headers": headers or {}}
|
||||
case "stdio":
|
||||
if not command:
|
||||
raise ValueError("command is required for stdio transport")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with stdio transport",
|
||||
why="command is required for stdio transport.",
|
||||
fix="Pass --command <executable> to hive mcp add.",
|
||||
)
|
||||
manifest["stdio"] = {
|
||||
"command": command,
|
||||
"args": args or [],
|
||||
@@ -175,15 +200,30 @@ class MCPRegistry:
|
||||
}
|
||||
case "unix":
|
||||
if not socket_path:
|
||||
raise ValueError("socket_path is required for unix transport")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with unix transport",
|
||||
why="socket_path is required for unix transport.",
|
||||
fix="Pass --socket-path /path/to/socket to hive mcp add.",
|
||||
)
|
||||
manifest["unix"] = {"socket_path": socket_path}
|
||||
manifest["http"] = {"url": url or "http://localhost"}
|
||||
case "sse":
|
||||
if not url:
|
||||
raise ValueError("url is required for sse transport")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}' with sse transport",
|
||||
why="url is required for sse transport.",
|
||||
fix="Pass --url https://your-server to hive mcp add.",
|
||||
)
|
||||
manifest["sse"] = {"url": url}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transport: {transport}")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot register server '{name}'",
|
||||
why=f"Unsupported transport: '{transport}'.",
|
||||
fix="Use one of: stdio, http, unix, sse.",
|
||||
)
|
||||
|
||||
entry = self._make_entry(
|
||||
source="local",
|
||||
@@ -203,34 +243,48 @@ class MCPRegistry:
|
||||
"""Install a server from the cached remote registry index."""
|
||||
data = self._read_installed()
|
||||
if name in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' already exists. Remove it first or use update.")
|
||||
raise MCPInstallError(
|
||||
server=name,
|
||||
why=f"Server '{name}' already exists in the registry.",
|
||||
fix=f"Run: hive mcp remove {name} — then install again.",
|
||||
)
|
||||
|
||||
index = self._read_cached_index()
|
||||
manifest = index.get("servers", {}).get(name)
|
||||
if manifest is None:
|
||||
raise ValueError(
|
||||
f"Server '{name}' not found in registry index. "
|
||||
"Run 'hive mcp update' to refresh the index."
|
||||
raise MCPInstallError(
|
||||
server=name,
|
||||
why=f"Server '{name}' not found in registry index.",
|
||||
fix="Run: hive mcp update — then try again.",
|
||||
)
|
||||
|
||||
# Validate version if specified
|
||||
if version is not None:
|
||||
index_version = manifest.get("version")
|
||||
if index_version is None:
|
||||
raise ValueError(f"Cannot pin version for '{name}': manifest has no version field.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Cannot pin version for '{name}'",
|
||||
why="The registry manifest has no version field.",
|
||||
fix="Run: hive mcp update — then omit --version to use latest.",
|
||||
)
|
||||
if index_version != version:
|
||||
raise ValueError(
|
||||
f"Version mismatch for '{name}': requested {version}, "
|
||||
f"index has {index_version}. "
|
||||
"Run 'hive mcp update' to refresh the index."
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_VERSION_CONFLICT,
|
||||
what=f"Version mismatch for '{name}'",
|
||||
why=f"Requested {version} but index has {index_version}.",
|
||||
fix="Run: hive mcp update — or omit --version to use latest.",
|
||||
)
|
||||
|
||||
transport_config = manifest.get("transport", {})
|
||||
supported = transport_config.get("supported", [])
|
||||
if transport is not None:
|
||||
if supported and transport not in supported:
|
||||
raise ValueError(
|
||||
f"Transport '{transport}' not supported by '{name}'. Supported: {supported}"
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Transport '{transport}' not supported by '{name}'",
|
||||
why=f"Server supports: {supported}.",
|
||||
fix=f"Use one of the supported transports: {supported}.",
|
||||
)
|
||||
resolved_transport = transport
|
||||
else:
|
||||
@@ -261,7 +315,12 @@ class MCPRegistry:
|
||||
"""Remove a server from the registry."""
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' is not installed.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot remove server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
del data["servers"][name]
|
||||
self._write_installed(data)
|
||||
logger.info("Removed MCP server '%s'", name)
|
||||
@@ -277,7 +336,12 @@ class MCPRegistry:
|
||||
def _set_enabled(self, name: str, *, enabled: bool) -> None:
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' is not installed.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot {'enable' if enabled else 'disable'} server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
data["servers"][name]["enabled"] = enabled
|
||||
self._write_installed(data)
|
||||
logger.info("%s MCP server '%s'", "Enabled" if enabled else "Disabled", name)
|
||||
@@ -314,9 +378,19 @@ class MCPRegistry:
|
||||
"""Set an env or header override for a server."""
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' is not installed.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot set override for server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
if override_type not in ("env", "headers"):
|
||||
raise ValueError(f"Invalid override type: {override_type}")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Invalid override type '{override_type}' for server '{name}'",
|
||||
why="Override type must be 'env' or 'headers'.",
|
||||
fix="Use --type env or --type headers.",
|
||||
)
|
||||
data["servers"][name]["overrides"][override_type][key] = value
|
||||
self._write_installed(data)
|
||||
logger.info("Set %s override %s for MCP server '%s'", override_type, key, name)
|
||||
@@ -401,14 +475,16 @@ class MCPRegistry:
|
||||
|
||||
# ── load_agent_selection ────────────────────────────────────────
|
||||
|
||||
def load_agent_selection(self, agent_path: Path) -> list[dict[str, Any]]:
|
||||
def load_agent_selection(self, agent_path: Path) -> tuple[list[dict[str, Any]], int | None]:
|
||||
"""Load mcp_registry.json from an agent directory and resolve servers.
|
||||
|
||||
Returns list of plain dicts compatible with ToolRegistry.register_mcp_server().
|
||||
Returns:
|
||||
(server_config_dicts, max_tools) for :meth:`ToolRegistry.load_registry_servers`.
|
||||
``max_tools`` is ``None`` when omitted or invalid in JSON.
|
||||
"""
|
||||
registry_json_path = agent_path / "mcp_registry.json"
|
||||
if not registry_json_path.exists():
|
||||
return []
|
||||
return [], None
|
||||
|
||||
selection = json.loads(registry_json_path.read_text(encoding="utf-8"))
|
||||
|
||||
@@ -437,15 +513,16 @@ class MCPRegistry:
|
||||
continue
|
||||
validated[field] = value
|
||||
|
||||
max_tools = validated.get("max_tools")
|
||||
configs = self.resolve_for_agent(
|
||||
include=validated.get("include"),
|
||||
tags=validated.get("tags"),
|
||||
exclude=validated.get("exclude"),
|
||||
profile=validated.get("profile"),
|
||||
max_tools=validated.get("max_tools"),
|
||||
max_tools=max_tools,
|
||||
versions=validated.get("versions"),
|
||||
)
|
||||
return [self._server_config_to_dict(c) for c in configs]
|
||||
return [self._server_config_to_dict(c) for c in configs], max_tools
|
||||
|
||||
# ── resolve_for_agent ───────────────────────────────────────────
|
||||
|
||||
@@ -552,12 +629,14 @@ class MCPRegistry:
|
||||
)
|
||||
continue
|
||||
|
||||
# Check tool count cap before adding (FR-56)
|
||||
# Check tool count cap before adding (FR-56), using manifest tool list when present.
|
||||
# When ``tools`` is empty (e.g. ``add_local``), counts are unknown here—callers should
|
||||
# pass the same ``max_tools`` to ToolRegistry.load_registry_servers to cap registration.
|
||||
manifest_tools = manifest.get("tools", [])
|
||||
server_tool_count = len(manifest_tools)
|
||||
if max_tools is not None and server_tool_count == 0:
|
||||
logger.debug(
|
||||
"Server '%s' has no declared tools in manifest, skipping max_tools check",
|
||||
"Server '%s' has no tools list in manifest; max_tools enforced at registration",
|
||||
name,
|
||||
)
|
||||
elif max_tools is not None and total_tools + server_tool_count > max_tools:
|
||||
@@ -693,7 +772,12 @@ class MCPRegistry:
|
||||
|
||||
data = self._read_installed()
|
||||
if name not in data["servers"]:
|
||||
raise ValueError(f"Server '{name}' is not installed.")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Cannot health-check server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list — to see installed servers.",
|
||||
)
|
||||
|
||||
entry = data["servers"][name]
|
||||
manifest = self._get_effective_manifest(name, entry)
|
||||
@@ -728,7 +812,12 @@ class MCPRegistry:
|
||||
if manager.has_connection(name):
|
||||
is_healthy = manager.health_check(name)
|
||||
if not is_healthy:
|
||||
raise RuntimeError("Shared MCP connection health check failed")
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what=f"Health check failed for server '{name}'",
|
||||
why="Shared MCP connection reported unhealthy.",
|
||||
fix=f"Run: hive mcp doctor {name} — for diagnostics.",
|
||||
)
|
||||
pooled_client = manager.acquire(config)
|
||||
try:
|
||||
tools = pooled_client.list_tools()
|
||||
|
||||
@@ -0,0 +1,906 @@
|
||||
"""CLI commands for MCP server registry management.
|
||||
|
||||
Commands:
|
||||
hive mcp install <name> Install a server from the registry
|
||||
hive mcp add Register a local/running MCP server
|
||||
hive mcp remove <name> Remove an installed server
|
||||
hive mcp enable <name> Enable a server
|
||||
hive mcp disable <name> Disable a server
|
||||
hive mcp list List installed servers
|
||||
hive mcp info <name> Show server details
|
||||
hive mcp config <name> Set env/header overrides
|
||||
hive mcp search <query> Search the registry index
|
||||
hive mcp health [name] Check server health
|
||||
hive mcp update Refresh index and update installed servers
|
||||
hive mcp update <name> Update a single installed server
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# ── Shared helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_registry(base_path: Path | None = None):
|
||||
"""Initialize and return an MCPRegistry instance."""
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
registry = MCPRegistry(base_path=base_path)
|
||||
registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
def _ensure_index_available(registry) -> bool:
|
||||
"""Ensure the registry index is cached locally.
|
||||
|
||||
If no index exists or the cache is stale, fetches a fresh copy.
|
||||
Returns True if a usable index exists, False otherwise.
|
||||
|
||||
Semantics:
|
||||
- Stale cache + refresh fails -> warn and continue with stale cache (True)
|
||||
- No cache + refresh fails -> hard fail (False)
|
||||
"""
|
||||
import httpx
|
||||
|
||||
cache_exists = (registry._cache_dir / "registry_index.json").exists()
|
||||
|
||||
if registry.is_index_stale():
|
||||
print("Updating registry index...", file=sys.stderr)
|
||||
try:
|
||||
count = registry.update_index()
|
||||
print(f"Registry index updated ({count} servers available).", file=sys.stderr)
|
||||
return True
|
||||
except (httpx.HTTPError, OSError) as exc:
|
||||
if cache_exists:
|
||||
print(
|
||||
f"Warning: failed to update registry index: {exc}\nUsing cached index.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return True
|
||||
print(
|
||||
f"Error: no registry index available and refresh failed: {exc}\n"
|
||||
"Check your network connection and try: hive mcp update",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
|
||||
return cache_exists
|
||||
|
||||
|
||||
_SECURITY_NOTICE = (
|
||||
"Registry servers run code on your machine. Only install servers you trust.\n"
|
||||
"Learn more: https://github.com/aden-hive/hive-mcp-registry"
|
||||
)
|
||||
_NOTICE_SENTINEL = ".security_notice_shown"
|
||||
|
||||
|
||||
def _print_security_notice_if_first_use(registry_base: Path) -> None:
|
||||
"""Print a one-time security notice on first registry install.
|
||||
|
||||
Only prints the notice. Call _mark_security_notice_shown() after
|
||||
a successful install to persist the sentinel.
|
||||
"""
|
||||
sentinel = registry_base / _NOTICE_SENTINEL
|
||||
if sentinel.exists():
|
||||
return
|
||||
print(f"\n {_SECURITY_NOTICE}\n", file=sys.stderr)
|
||||
|
||||
|
||||
def _mark_security_notice_shown(registry_base: Path) -> None:
|
||||
"""Persist the security notice sentinel after a successful install."""
|
||||
sentinel = registry_base / _NOTICE_SENTINEL
|
||||
try:
|
||||
sentinel.touch()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _prompt_for_missing_credentials(
|
||||
registry,
|
||||
name: str,
|
||||
manifest: dict,
|
||||
) -> None:
|
||||
"""Prompt for required credentials not already set in env or overrides."""
|
||||
credentials = manifest.get("credentials", [])
|
||||
if not credentials:
|
||||
return
|
||||
|
||||
server = registry.get_server(name)
|
||||
existing_overrides = server.get("overrides", {}).get("env", {}) if server else {}
|
||||
|
||||
prompted = False
|
||||
for cred in credentials:
|
||||
if not isinstance(cred, dict):
|
||||
continue
|
||||
env_var = cred.get("env_var", "")
|
||||
if not env_var:
|
||||
continue
|
||||
required = cred.get("required", False)
|
||||
if not required:
|
||||
continue
|
||||
|
||||
# Skip if already in environment or overrides
|
||||
if os.environ.get(env_var) or existing_overrides.get(env_var):
|
||||
continue
|
||||
|
||||
if not prompted:
|
||||
print(f"\n{name} requires credentials:", file=sys.stderr)
|
||||
prompted = True
|
||||
|
||||
description = cred.get("description", env_var)
|
||||
help_url = cred.get("help_url", "")
|
||||
help_hint = f" (get one at {help_url})" if help_url else ""
|
||||
|
||||
try:
|
||||
value = input(f" {description}{help_hint}\n {env_var}: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nSkipped credential prompting.", file=sys.stderr)
|
||||
return
|
||||
|
||||
if value:
|
||||
registry.set_override(name, env_var, value, override_type="env")
|
||||
|
||||
|
||||
def _parse_key_value_pairs(values: list[str]) -> dict[str, str]:
|
||||
"""Parse KEY=VAL pairs from CLI args. Raises ValueError on bad format."""
|
||||
result = {}
|
||||
for item in values:
|
||||
if "=" not in item:
|
||||
raise ValueError(
|
||||
f"Invalid format: '{item}'. Expected KEY=VALUE.\n"
|
||||
f"Example: --set JIRA_API_TOKEN=abc123"
|
||||
)
|
||||
key, _, value = item.partition("=")
|
||||
if not key:
|
||||
raise ValueError(f"Invalid format: '{item}'. Key cannot be empty.")
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _find_agents_using_server(registry, name: str) -> list[str]:
|
||||
"""Scan agent directories for mcp_registry.json files that would load a server.
|
||||
|
||||
Uses MCPRegistry.load_agent_selection() to resolve actual selection logic
|
||||
so results stay consistent with runtime behavior.
|
||||
"""
|
||||
agent_dirs: list[Path] = []
|
||||
# parents: [0]=runner, [1]=framework, [2]=core, [3]=hive (project root)
|
||||
# NOTE: This path arithmetic assumes running from the source tree layout.
|
||||
# It will not resolve correctly if installed via pip into site-packages.
|
||||
project_root = Path(__file__).resolve().parents[3]
|
||||
core_dir = Path(__file__).resolve().parents[2]
|
||||
|
||||
candidates = [
|
||||
project_root / "exports",
|
||||
core_dir / "exports",
|
||||
core_dir / "framework" / "agents",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.is_dir():
|
||||
for child in candidate.iterdir():
|
||||
if child.is_dir():
|
||||
agent_dirs.append(child)
|
||||
|
||||
matches = []
|
||||
for agent_dir in agent_dirs:
|
||||
registry_json = agent_dir / "mcp_registry.json"
|
||||
if not registry_json.exists():
|
||||
continue
|
||||
try:
|
||||
configs = registry.load_agent_selection(agent_dir)
|
||||
resolved_names = {c["name"] for c in configs}
|
||||
if name in resolved_names:
|
||||
matches.append(str(agent_dir))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _render_installed_table(entries: list[dict]) -> None:
|
||||
"""Render installed servers as a formatted table."""
|
||||
if not entries:
|
||||
print("No servers installed.")
|
||||
print("Run 'hive mcp install <name>' or 'hive mcp add' to get started.")
|
||||
return
|
||||
|
||||
# Column widths
|
||||
name_w = max(len(e["name"]) for e in entries)
|
||||
name_w = max(name_w, 4)
|
||||
transport_w = max(len(e.get("transport", "")) for e in entries)
|
||||
transport_w = max(transport_w, 9)
|
||||
|
||||
header = (
|
||||
f" {'NAME':<{name_w}} "
|
||||
f"{'TRANSPORT':<{transport_w}} "
|
||||
f"{'ENABLED':<7} "
|
||||
f"{'HEALTH':<9} "
|
||||
f"{'TOOLS':<5} "
|
||||
f"{'TRUST':<10} "
|
||||
f"{'SOURCE'}"
|
||||
)
|
||||
print(header)
|
||||
print(" " + "─" * (len(header) - 2))
|
||||
|
||||
for entry in entries:
|
||||
enabled = "yes" if entry.get("enabled", True) else "no"
|
||||
health = entry.get("last_health_status") or "unknown"
|
||||
health_sym = {"healthy": "✓", "unhealthy": "✗"}.get(health, "●")
|
||||
source = entry.get("source", "")
|
||||
manifest = entry.get("manifest", {})
|
||||
tools_count = str(len(manifest.get("tools", [])))
|
||||
trust_tier = manifest.get("status", "")
|
||||
print(
|
||||
f" {entry['name']:<{name_w}} "
|
||||
f"{entry.get('transport', ''):<{transport_w}} "
|
||||
f"{enabled:<7} "
|
||||
f"{health_sym} {health:<7} "
|
||||
f"{tools_count:<5} "
|
||||
f"{trust_tier:<10} "
|
||||
f"{source}"
|
||||
)
|
||||
|
||||
|
||||
def _render_available_table(entries: list[dict]) -> None:
|
||||
"""Render available registry servers as a formatted table."""
|
||||
if not entries:
|
||||
print("No servers in registry index.")
|
||||
print("Run 'hive mcp update' to refresh the index.")
|
||||
return
|
||||
|
||||
name_w = max(len(e["name"]) for e in entries)
|
||||
name_w = max(name_w, 4)
|
||||
|
||||
header = f" {'NAME':<{name_w}} {'VERSION':<9} {'STATUS':<10} DESCRIPTION"
|
||||
print(header)
|
||||
print(" " + "─" * (len(header) - 2))
|
||||
|
||||
for entry in entries:
|
||||
version = entry.get("version", "")
|
||||
status = entry.get("status", "community")
|
||||
desc = entry.get("description", "")
|
||||
# Truncate long descriptions
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
print(f" {entry['name']:<{name_w}} {version:<9} {status:<10} {desc}")
|
||||
|
||||
|
||||
def _mask_overrides(overrides: dict) -> dict:
|
||||
"""Replace override values with '<set>' markers. Shared by all output paths."""
|
||||
masked: dict[str, dict[str, str]] = {}
|
||||
if overrides.get("env"):
|
||||
masked["env"] = dict.fromkeys(overrides["env"], "<set>")
|
||||
else:
|
||||
masked["env"] = {}
|
||||
if overrides.get("headers"):
|
||||
masked["headers"] = dict.fromkeys(overrides["headers"], "<set>")
|
||||
else:
|
||||
masked["headers"] = {}
|
||||
return masked
|
||||
|
||||
|
||||
def _emit_json(data: Any) -> None:
|
||||
"""Print data as formatted JSON."""
|
||||
print(json.dumps(data, indent=2, default=str))
|
||||
|
||||
|
||||
# ── Command registration ───────────────────────────────────────────
|
||||
|
||||
|
||||
def register_mcp_commands(subparsers) -> None:
|
||||
"""Register the ``hive mcp`` subcommand group."""
|
||||
mcp_parser = subparsers.add_parser("mcp", help="Manage MCP servers")
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_command", required=True)
|
||||
|
||||
# ── install ──
|
||||
install_p = mcp_sub.add_parser("install", help="Install a server from the registry")
|
||||
install_p.add_argument("name", help="Server name in the registry")
|
||||
install_p.add_argument(
|
||||
"--version", dest="version", default=None, help="Pin to a specific version"
|
||||
)
|
||||
install_p.add_argument(
|
||||
"--transport", default=None, help="Override default transport (stdio, http, unix, sse)"
|
||||
)
|
||||
install_p.set_defaults(func=cmd_mcp_install)
|
||||
|
||||
# ── add ──
|
||||
add_p = mcp_sub.add_parser("add", help="Register a local/running MCP server")
|
||||
add_p.add_argument("--name", required=False, help="Server name")
|
||||
add_p.add_argument(
|
||||
"--transport",
|
||||
choices=["stdio", "http", "unix", "sse"],
|
||||
default=None,
|
||||
help="Transport type",
|
||||
)
|
||||
add_p.add_argument("--url", default=None, help="Server URL (http, unix, sse)")
|
||||
add_p.add_argument("--command", default=None, help="Command to run (stdio)")
|
||||
add_p.add_argument("--args", nargs="*", default=None, help="Command arguments (stdio)")
|
||||
add_p.add_argument("--socket-path", default=None, help="Unix socket path")
|
||||
add_p.add_argument("--description", default="", help="Server description")
|
||||
add_p.add_argument("--from", dest="from_manifest", default=None, help="Path to manifest.json")
|
||||
add_p.set_defaults(func=cmd_mcp_add)
|
||||
|
||||
# ── remove ──
|
||||
remove_p = mcp_sub.add_parser("remove", help="Remove an installed server")
|
||||
remove_p.add_argument("name", help="Server name")
|
||||
remove_p.set_defaults(func=cmd_mcp_remove)
|
||||
|
||||
# ── enable ──
|
||||
enable_p = mcp_sub.add_parser("enable", help="Enable a disabled server")
|
||||
enable_p.add_argument("name", help="Server name")
|
||||
enable_p.set_defaults(func=cmd_mcp_enable)
|
||||
|
||||
# ── disable ──
|
||||
disable_p = mcp_sub.add_parser("disable", help="Disable a server without removing it")
|
||||
disable_p.add_argument("name", help="Server name")
|
||||
disable_p.set_defaults(func=cmd_mcp_disable)
|
||||
|
||||
# ── list ──
|
||||
list_p = mcp_sub.add_parser("list", help="List servers")
|
||||
list_p.add_argument(
|
||||
"--available", action="store_true", help="Show available servers from registry"
|
||||
)
|
||||
list_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
list_p.set_defaults(func=cmd_mcp_list)
|
||||
|
||||
# ── info ──
|
||||
info_p = mcp_sub.add_parser("info", help="Show server details")
|
||||
info_p.add_argument("name", help="Server name")
|
||||
info_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
info_p.set_defaults(func=cmd_mcp_info)
|
||||
|
||||
# ── config ──
|
||||
config_p = mcp_sub.add_parser("config", help="Set server configuration overrides")
|
||||
config_p.add_argument("name", help="Server name")
|
||||
config_p.add_argument(
|
||||
"--set",
|
||||
dest="set_env",
|
||||
nargs="+",
|
||||
metavar="KEY=VAL",
|
||||
help="Set environment variable overrides",
|
||||
)
|
||||
config_p.add_argument(
|
||||
"--set-header", dest="set_header", nargs="+", metavar="KEY=VAL", help="Set header overrides"
|
||||
)
|
||||
config_p.set_defaults(func=cmd_mcp_config)
|
||||
|
||||
# ── search ──
|
||||
search_p = mcp_sub.add_parser("search", help="Search the registry")
|
||||
search_p.add_argument("query", help="Search term (name, tag, description, tool name)")
|
||||
search_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
search_p.set_defaults(func=cmd_mcp_search)
|
||||
|
||||
# ── health ──
|
||||
health_p = mcp_sub.add_parser("health", help="Check server health")
|
||||
health_p.add_argument("name", nargs="?", default=None, help="Server name (all if omitted)")
|
||||
health_p.add_argument("--json", dest="output_json", action="store_true", help="Output as JSON")
|
||||
health_p.set_defaults(func=cmd_mcp_health)
|
||||
|
||||
# ── update ──
|
||||
update_p = mcp_sub.add_parser(
|
||||
"update", help="Update installed servers or refresh the registry index"
|
||||
)
|
||||
update_p.add_argument(
|
||||
"name",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="Server name to update (omit to update all registry servers)",
|
||||
)
|
||||
update_p.set_defaults(func=cmd_mcp_update)
|
||||
|
||||
|
||||
# ── P0 command handlers ────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_mcp_install(args) -> int:
|
||||
"""Install a server from the registry index."""
|
||||
registry = _get_registry()
|
||||
_print_security_notice_if_first_use(registry._base)
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.install(
|
||||
args.name,
|
||||
transport=args.transport,
|
||||
version=args.version,
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
_mark_security_notice_shown(registry._base)
|
||||
|
||||
version_str = entry.get("manifest_version", "")
|
||||
transport = entry.get("transport", "")
|
||||
print(f"✓ Installed {args.name} v{version_str} ({transport})")
|
||||
|
||||
# Prompt for credentials defined in the manifest
|
||||
manifest = entry.get("manifest", {})
|
||||
_prompt_for_missing_credentials(registry, args.name, manifest)
|
||||
|
||||
print("\nNext steps:")
|
||||
print(f" hive mcp health {args.name} Check that the server is reachable")
|
||||
print(f" hive mcp info {args.name} View server details")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_add(args) -> int:
|
||||
"""Register a local/running MCP server."""
|
||||
registry = _get_registry()
|
||||
|
||||
# Handle --from manifest.json
|
||||
if args.from_manifest:
|
||||
return _cmd_mcp_add_from_manifest(registry, args.from_manifest)
|
||||
|
||||
if not args.name:
|
||||
print(
|
||||
"Error: --name is required.\n"
|
||||
"Usage: hive mcp add --name my-server --transport http --url http://localhost:8080\n"
|
||||
" or: hive mcp add --from manifest.json",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if not args.transport:
|
||||
print(
|
||||
f"Error: --transport is required.\n"
|
||||
f"Supported transports: stdio, http, unix, sse\n"
|
||||
f"Example: hive mcp add --name {args.name} --transport http --url http://localhost:8080",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.add_local(
|
||||
name=args.name,
|
||||
transport=args.transport,
|
||||
url=args.url,
|
||||
command=args.command,
|
||||
args=args.args,
|
||||
socket_path=args.socket_path,
|
||||
description=args.description,
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registered {args.name} ({entry['transport']})")
|
||||
return 0
|
||||
|
||||
|
||||
def _cmd_mcp_add_from_manifest(registry, manifest_path: str) -> int:
|
||||
"""Register a server from a manifest.json file."""
|
||||
path = Path(manifest_path)
|
||||
if not path.exists():
|
||||
print(
|
||||
f"Error: manifest file not found: {manifest_path}\nCheck the path and try again.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
manifest = json.loads(path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
print(
|
||||
f"Error: invalid JSON in {manifest_path}: {exc}\n"
|
||||
f"Validate with: python -m json.tool {manifest_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
name = manifest.get("name")
|
||||
if not name:
|
||||
print(
|
||||
f"Error: manifest missing 'name' field.\nAdd a 'name' field to {manifest_path}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
try:
|
||||
entry = registry.add_local(name=name, manifest=manifest)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registered {name} from {manifest_path} ({entry['transport']})")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_remove(args) -> int:
|
||||
"""Remove an installed server."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.remove(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Removed {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_enable(args) -> int:
|
||||
"""Enable a disabled server."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.enable(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Enabled {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_disable(args) -> int:
|
||||
"""Disable a server without removing it."""
|
||||
registry = _get_registry()
|
||||
try:
|
||||
registry.disable(args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"✓ Disabled {args.name}")
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_list(args) -> int:
|
||||
"""List installed or available servers."""
|
||||
registry = _get_registry()
|
||||
|
||||
if args.available:
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
entries = registry.list_available()
|
||||
if args.output_json:
|
||||
_emit_json(entries)
|
||||
else:
|
||||
_render_available_table(entries)
|
||||
else:
|
||||
entries = registry.list_installed()
|
||||
if args.output_json:
|
||||
safe_entries = []
|
||||
for entry in entries:
|
||||
safe = dict(entry)
|
||||
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
|
||||
safe_entries.append(safe)
|
||||
_emit_json(safe_entries)
|
||||
else:
|
||||
_render_installed_table(entries)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_info(args) -> int:
|
||||
"""Show full details for a server."""
|
||||
registry = _get_registry()
|
||||
server = registry.get_server(args.name)
|
||||
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{args.name}' is not installed.\n"
|
||||
f"Run 'hive mcp list' to see installed servers.\n"
|
||||
f"Run 'hive mcp install {args.name}' to install from registry.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
# Enrich with agent usage for both JSON and human output
|
||||
agents = _find_agents_using_server(registry, args.name)
|
||||
if agents:
|
||||
server["used_by_agents"] = agents
|
||||
|
||||
if args.output_json:
|
||||
safe = dict(server)
|
||||
safe["overrides"] = _mask_overrides(safe.get("overrides", {}))
|
||||
_emit_json(safe)
|
||||
return 0
|
||||
|
||||
manifest = server.get("manifest", {})
|
||||
overrides = _mask_overrides(server.get("overrides", {}))
|
||||
tools = manifest.get("tools", [])
|
||||
status = manifest.get("status", "community")
|
||||
hive_block = manifest.get("hive", {})
|
||||
|
||||
print(f"{server['name']}")
|
||||
print("=" * 50)
|
||||
|
||||
# Core info
|
||||
print(f" Source: {server.get('source', '')}")
|
||||
print(f" Transport: {server.get('transport', '')}")
|
||||
print(f" Version: {server.get('manifest_version', 'unknown')}")
|
||||
print(f" Trust tier: {status}")
|
||||
print(f" Enabled: {'yes' if server.get('enabled', True) else 'no'}")
|
||||
|
||||
# Description
|
||||
desc = manifest.get("description", "")
|
||||
if desc:
|
||||
print(f" Description: {desc}")
|
||||
|
||||
# Health
|
||||
health = server.get("last_health_status")
|
||||
if health:
|
||||
health_sym = {"healthy": "✓", "unhealthy": "✗"}.get(health, "●")
|
||||
print(f" Health: {health_sym} {health}")
|
||||
last_check = server.get("last_health_check_at")
|
||||
if last_check:
|
||||
print(f" Last check: {last_check}")
|
||||
last_error = server.get("last_error")
|
||||
if last_error:
|
||||
print(f" Last error: {last_error}")
|
||||
|
||||
# Tools
|
||||
if tools:
|
||||
print(f"\n Tools ({len(tools)}):")
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_name = tool.get("name", "")
|
||||
tool_desc = tool.get("description", "")
|
||||
print(f" • {tool_name}: {tool_desc}" if tool_desc else f" • {tool_name}")
|
||||
else:
|
||||
print(f" • {tool}")
|
||||
|
||||
# Overrides
|
||||
env_overrides = overrides.get("env", {})
|
||||
header_overrides = overrides.get("headers", {})
|
||||
if env_overrides or header_overrides:
|
||||
print("\n Overrides:")
|
||||
for key in env_overrides:
|
||||
print(f" env.{key} = <set>")
|
||||
for key in header_overrides:
|
||||
print(f" header.{key} = <set>")
|
||||
|
||||
# Hive block
|
||||
if hive_block:
|
||||
profiles = hive_block.get("profiles", [])
|
||||
if profiles:
|
||||
print(f"\n Profiles: {', '.join(profiles)}")
|
||||
min_ver = hive_block.get("min_version")
|
||||
if min_ver:
|
||||
print(f" Min Hive version: {min_ver}")
|
||||
|
||||
# Agent usage
|
||||
if agents:
|
||||
print("\n Used by agents:")
|
||||
for agent in agents:
|
||||
print(f" • {agent}")
|
||||
|
||||
# Timestamps
|
||||
print(f"\n Installed: {server.get('installed_at', 'unknown')}")
|
||||
print(f" Installed by: {server.get('installed_by', 'unknown')}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_config(args) -> int:
|
||||
"""Set env or header overrides for a server."""
|
||||
registry = _get_registry()
|
||||
|
||||
if not args.set_env and not args.set_header:
|
||||
# Show current config
|
||||
server = registry.get_server(args.name)
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{args.name}' is not installed.\n"
|
||||
f"Run 'hive mcp list' to see installed servers.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
masked = _mask_overrides(server.get("overrides", {}))
|
||||
env_o = masked.get("env", {})
|
||||
header_o = masked.get("headers", {})
|
||||
if not env_o and not header_o:
|
||||
print(f"No overrides set for {args.name}.")
|
||||
print(f"Set one with: hive mcp config {args.name} --set KEY=VALUE")
|
||||
else:
|
||||
print(f"Overrides for {args.name}:")
|
||||
for key in env_o:
|
||||
print(f" env.{key} = <set>")
|
||||
for key in header_o:
|
||||
print(f" header.{key} = <set>")
|
||||
return 0
|
||||
|
||||
try:
|
||||
if args.set_env:
|
||||
pairs = _parse_key_value_pairs(args.set_env)
|
||||
for key, value in pairs.items():
|
||||
registry.set_override(args.name, key, value, override_type="env")
|
||||
print(f"✓ Set {len(pairs)} env override(s) for {args.name}")
|
||||
|
||||
if args.set_header:
|
||||
pairs = _parse_key_value_pairs(args.set_header)
|
||||
for key, value in pairs.items():
|
||||
registry.set_override(args.name, key, value, override_type="headers")
|
||||
print(f"✓ Set {len(pairs)} header override(s) for {args.name}")
|
||||
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
# ── P1 command handlers ────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_mcp_search(args) -> int:
|
||||
"""Search the registry index."""
|
||||
registry = _get_registry()
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
results = registry.search(args.query)
|
||||
|
||||
if args.output_json:
|
||||
_emit_json(results)
|
||||
return 0
|
||||
|
||||
if not results:
|
||||
print(f"No servers matching '{args.query}'.")
|
||||
return 0
|
||||
|
||||
print(f"Found {len(results)} server(s) matching '{args.query}':\n")
|
||||
_render_available_table(results)
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_health(args) -> int:
|
||||
"""Check server health."""
|
||||
registry = _get_registry()
|
||||
|
||||
try:
|
||||
results = registry.health_check(name=args.name)
|
||||
except ValueError as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
# Single server returns a flat dict, all-servers returns name->dict
|
||||
if args.name:
|
||||
results = {args.name: results}
|
||||
|
||||
if args.output_json:
|
||||
_emit_json(results)
|
||||
return 0
|
||||
|
||||
for name, result in results.items():
|
||||
status = result.get("status", "unknown")
|
||||
tools = result.get("tools", 0)
|
||||
error = result.get("error")
|
||||
sym = {"healthy": "✓", "unhealthy": "✗"}.get(status, "●")
|
||||
|
||||
print(f" {sym} {name}: {status}", end="")
|
||||
if status == "healthy" and tools:
|
||||
print(f" ({tools} tools)")
|
||||
elif error:
|
||||
print(f"\n Error: {error}")
|
||||
else:
|
||||
print()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_mcp_update(args) -> int:
|
||||
"""Update a single server, or refresh the index and update all registry servers."""
|
||||
registry = _get_registry()
|
||||
|
||||
if args.name:
|
||||
return _cmd_mcp_update_server(args.name, registry)
|
||||
|
||||
# Step 1: refresh the registry index
|
||||
try:
|
||||
count = registry.update_index()
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"Error: failed to update registry index: {exc}\n"
|
||||
f"Check your network connection and try again.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
print(f"✓ Registry index updated ({count} servers available)")
|
||||
|
||||
# Step 2: update all installed registry servers (skip local/pinned)
|
||||
installed = registry.list_installed()
|
||||
registry_servers = [
|
||||
s for s in installed if s.get("source") == "registry" and not s.get("pinned")
|
||||
]
|
||||
|
||||
if not registry_servers:
|
||||
return 0
|
||||
|
||||
print(f"\nUpdating {len(registry_servers)} installed server(s)...")
|
||||
errors = 0
|
||||
for server in registry_servers:
|
||||
name = server["name"]
|
||||
rc = _cmd_mcp_update_server(name, registry)
|
||||
if rc != 0:
|
||||
errors += 1
|
||||
|
||||
return 1 if errors else 0
|
||||
|
||||
|
||||
def _cmd_mcp_update_server(name: str, registry=None) -> int:
|
||||
"""Bridge: reinstall a server from the latest index.
|
||||
|
||||
This is a temporary bridge until #6355 adds proper version diffing,
|
||||
tool-signature change detection, and --dry-run support.
|
||||
"""
|
||||
if registry is None:
|
||||
registry = _get_registry()
|
||||
|
||||
server = registry.get_server(name)
|
||||
if server is None:
|
||||
print(
|
||||
f"Error: server '{name}' is not installed.\n"
|
||||
f"Run 'hive mcp install {name}' to install it.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if server.get("source") != "registry":
|
||||
print(
|
||||
f"Error: '{name}' is a local server and cannot be updated from the registry.\n"
|
||||
f"Use 'hive mcp remove {name}' and 'hive mcp add' to re-register it.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
if server.get("pinned"):
|
||||
print(
|
||||
f"Error: '{name}' is pinned to v{server.get('manifest_version', '?')}.\n"
|
||||
f"To update a pinned server, remove and reinstall:\n"
|
||||
f" hive mcp remove {name} && hive mcp install {name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
# Refresh index, then reinstall
|
||||
if not _ensure_index_available(registry):
|
||||
return 1
|
||||
|
||||
old_version = server.get("manifest_version", "unknown")
|
||||
transport = server.get("transport")
|
||||
overrides = server.get("overrides", {})
|
||||
was_enabled = server.get("enabled", True)
|
||||
|
||||
# Save the full entry before removing so we can restore on failure
|
||||
saved_entry = dict(server)
|
||||
saved_entry.pop("name", None)
|
||||
|
||||
try:
|
||||
registry.remove(name)
|
||||
entry = registry.install(name, transport=transport)
|
||||
except ValueError as exc:
|
||||
# Restore the original entry so update doesn't become an uninstall
|
||||
data = registry._read_installed()
|
||||
data["servers"][name] = saved_entry
|
||||
registry._write_installed(data)
|
||||
print(
|
||||
f"Error: {exc}\nServer '{name}' has been restored to its previous state.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
new_version = entry.get("manifest_version", "unknown")
|
||||
|
||||
# Restore prior state from the previous installation
|
||||
for key, value in overrides.get("env", {}).items():
|
||||
registry.set_override(name, key, value, override_type="env")
|
||||
for key, value in overrides.get("headers", {}).items():
|
||||
registry.set_override(name, key, value, override_type="headers")
|
||||
if not was_enabled:
|
||||
registry.disable(name)
|
||||
|
||||
if old_version == new_version:
|
||||
print(f"✓ {name} is already at v{new_version}")
|
||||
else:
|
||||
print(f"✓ Updated {name}: v{old_version} → v{new_version}")
|
||||
|
||||
return 0
|
||||
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_INDEX_PATH = Path.home() / ".hive" / "mcp_registry" / "cache" / "registry_index.json"
|
||||
_FIXTURE_INDEX_PATH = Path(__file__).resolve().parent / "fixtures" / "registry_index.json"
|
||||
|
||||
|
||||
def resolve_registry_servers(
|
||||
*,
|
||||
include: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
profile: str | None = None,
|
||||
max_tools: int | None = None,
|
||||
versions: dict[str, str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve registry-sourced MCP servers for `mcp_registry.json` selection.
|
||||
|
||||
This function is written to be mock-friendly during early development:
|
||||
- If the real `MCPRegistry` core module is present, delegate to it.
|
||||
- Otherwise, fall back to a cached local index (`~/.hive/.../registry_index.json`)
|
||||
and then to the repo fixture index.
|
||||
"""
|
||||
|
||||
# `max_tools` is enforced by ToolRegistry. We keep it in the resolver
|
||||
# signature to match the PRD and future MCPRegistry interfaces.
|
||||
_ = max_tools
|
||||
|
||||
try:
|
||||
from framework.runner.mcp_registry import MCPRegistry # type: ignore
|
||||
|
||||
registry = MCPRegistry()
|
||||
resolved = registry.resolve_for_agent(
|
||||
include=include or [],
|
||||
tags=tags or [],
|
||||
exclude=exclude or [],
|
||||
profile=profile,
|
||||
max_tools=max_tools,
|
||||
versions=versions or {},
|
||||
)
|
||||
# Future-proof: normalize both dicts and typed objects to dicts.
|
||||
return [_normalize_server_config(x) for x in resolved]
|
||||
except ImportError:
|
||||
# Expected while #6349/#6574 is not merged locally.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("MCPRegistry resolution failed; falling back to cache/fixtures: %s", e)
|
||||
|
||||
return _resolve_from_local_index(
|
||||
include=include,
|
||||
tags=tags,
|
||||
exclude=exclude,
|
||||
profile=profile,
|
||||
versions=versions or {},
|
||||
)
|
||||
|
||||
|
||||
def _resolve_from_local_index(
|
||||
*,
|
||||
include: list[str] | None,
|
||||
tags: list[str] | None,
|
||||
exclude: list[str] | None,
|
||||
profile: str | None,
|
||||
versions: dict[str, str],
|
||||
) -> list[dict[str, Any]]:
|
||||
index = _load_index_json()
|
||||
servers = _coerce_index_servers(index)
|
||||
servers_by_name: dict[str, dict[str, Any]] = {
|
||||
s["name"]: s for s in servers if isinstance(s, dict) and "name" in s
|
||||
}
|
||||
|
||||
include_list = include or []
|
||||
tags_list = tags or []
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
def _profiles_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("profiles"), list):
|
||||
return set(entry["profiles"])
|
||||
hive = entry.get("hive")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("profiles"), list):
|
||||
return set(hive["profiles"])
|
||||
return set()
|
||||
|
||||
def _tags_of(entry: dict[str, Any]) -> set[str]:
|
||||
if isinstance(entry.get("tags"), list):
|
||||
return set(entry["tags"])
|
||||
return set()
|
||||
|
||||
def _entry_version(entry: dict[str, Any]) -> str | None:
|
||||
# Prefer flat `version`, but support a few common shapes.
|
||||
v = entry.get("version")
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
v2 = entry.get("manifest_version")
|
||||
if isinstance(v2, str):
|
||||
return v2
|
||||
hive = entry.get("manifest")
|
||||
if isinstance(hive, dict) and isinstance(hive.get("version"), str):
|
||||
return hive["version"]
|
||||
return None
|
||||
|
||||
def _version_allows(server_name: str) -> bool:
|
||||
if server_name not in versions:
|
||||
return True
|
||||
pinned = versions[server_name]
|
||||
entry = servers_by_name.get(server_name)
|
||||
if not entry:
|
||||
return False
|
||||
return _entry_version(entry) == pinned
|
||||
|
||||
resolved_names: list[str] = []
|
||||
resolved_set: set[str] = set()
|
||||
|
||||
# 1) Include-order first
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name in servers_by_name and _version_allows(name) and name not in resolved_set:
|
||||
resolved_names.append(name)
|
||||
resolved_set.add(name)
|
||||
|
||||
# 2) Then tag/profile matches, alphabetical
|
||||
profile_candidates = set()
|
||||
if profile:
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if profile in _profiles_of(entry):
|
||||
profile_candidates.add(name)
|
||||
|
||||
tag_candidates = set()
|
||||
if tags_list:
|
||||
tags_set = set(tags_list)
|
||||
for name, entry in servers_by_name.items():
|
||||
if name in exclude_set or not _version_allows(name):
|
||||
continue
|
||||
if _tags_of(entry).intersection(tags_set):
|
||||
tag_candidates.add(name)
|
||||
|
||||
tag_profile_names = sorted((profile_candidates | tag_candidates) - resolved_set)
|
||||
resolved_names.extend(tag_profile_names)
|
||||
|
||||
# Missing requested servers should warn (FR-54).
|
||||
for name in include_list:
|
||||
if name in exclude_set:
|
||||
continue
|
||||
if name not in resolved_set:
|
||||
if name not in servers_by_name:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json but not found in index. "
|
||||
"Run: hive mcp install %s",
|
||||
name,
|
||||
name,
|
||||
)
|
||||
elif name in versions:
|
||||
logger.warning(
|
||||
"Server '%s' was requested but pinned version '%s' was not found in index. "
|
||||
"Run: hive mcp update %s or change the pin in mcp_registry.json",
|
||||
name,
|
||||
versions[name],
|
||||
name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Server '%s' requested by mcp_registry.json was not selected. "
|
||||
"Check selection filters/exclude lists.",
|
||||
name,
|
||||
)
|
||||
|
||||
resolved_configs: list[dict[str, Any]] = []
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
for name in resolved_names:
|
||||
entry = servers_by_name.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
config = entry.get("mcp_config")
|
||||
if not isinstance(config, dict):
|
||||
# Best-effort: allow a direct MCP config shape at top-level.
|
||||
config = {
|
||||
k: v
|
||||
for k, v in entry.items()
|
||||
if k
|
||||
in {
|
||||
"name",
|
||||
"transport",
|
||||
"command",
|
||||
"args",
|
||||
"env",
|
||||
"cwd",
|
||||
"url",
|
||||
"headers",
|
||||
"description",
|
||||
}
|
||||
}
|
||||
mcp_config = dict(config)
|
||||
mcp_config["name"] = name
|
||||
if mcp_config.get("transport") == "stdio":
|
||||
_absolutize_stdio_config_in_place(repo_root, mcp_config)
|
||||
resolved_configs.append(mcp_config)
|
||||
|
||||
return resolved_configs
|
||||
|
||||
|
||||
def _load_index_json() -> Any:
|
||||
if _CACHE_INDEX_PATH.exists():
|
||||
return json.loads(_CACHE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
if _FIXTURE_INDEX_PATH.exists():
|
||||
logger.info("Using local fixture index because registry cache is missing")
|
||||
return json.loads(_FIXTURE_INDEX_PATH.read_text(encoding="utf-8"))
|
||||
logger.warning("No local MCP registry index found (cache and fixture missing)")
|
||||
return {"servers": []}
|
||||
|
||||
|
||||
def _coerce_index_servers(index: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(index, list):
|
||||
return [x for x in index if isinstance(x, dict)]
|
||||
if isinstance(index, dict):
|
||||
servers = index.get("servers", [])
|
||||
if isinstance(servers, list):
|
||||
return [x for x in servers if isinstance(x, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_server_config(raw: Any) -> dict[str, Any]:
|
||||
if isinstance(raw, dict):
|
||||
return dict(raw)
|
||||
|
||||
# Future-proof object-to-dict normalization.
|
||||
for attr in ("to_dict", "model_dump"):
|
||||
maybe = getattr(raw, attr, None)
|
||||
if callable(maybe):
|
||||
return dict(maybe())
|
||||
|
||||
return dict(getattr(raw, "__dict__", {}))
|
||||
|
||||
|
||||
def _absolutize_stdio_config_in_place(repo_root: Path, config: dict[str, Any]) -> None:
|
||||
cwd = config.get("cwd")
|
||||
if isinstance(cwd, str) and not Path(cwd).is_absolute():
|
||||
config["cwd"] = str((repo_root / cwd).resolve())
|
||||
|
||||
# We intentionally do not absolutize `args` here.
|
||||
# For stdio servers, arguments may include the script name relative to
|
||||
# `cwd` (e.g. "coder_tools_server.py" with cwd="tools"). ToolRegistry's
|
||||
# stdio resolution logic handles script path checks and platform quirks.
|
||||
@@ -1429,12 +1429,18 @@ class AgentRunner:
|
||||
|
||||
def _load_registry_mcp_servers(self, agent_path: Path) -> None:
|
||||
"""Load and register MCP servers selected via ``mcp_registry.json``."""
|
||||
registry_json = agent_path / "mcp_registry.json"
|
||||
if registry_json.is_file():
|
||||
self._tool_registry.set_mcp_registry_agent_path(agent_path)
|
||||
else:
|
||||
self._tool_registry.set_mcp_registry_agent_path(None)
|
||||
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
server_configs = registry.load_agent_selection(agent_path)
|
||||
server_configs, selection_max_tools = registry.load_agent_selection(agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to load MCP registry servers for '%s': %s",
|
||||
@@ -1446,7 +1452,12 @@ class AgentRunner:
|
||||
if not server_configs:
|
||||
return
|
||||
|
||||
results = self._tool_registry.load_registry_servers(server_configs)
|
||||
results = self._tool_registry.load_registry_servers(
|
||||
server_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
loaded = [result for result in results if result["status"] == "loaded"]
|
||||
skipped = [result for result in results if result["status"] != "loaded"]
|
||||
|
||||
|
||||
@@ -66,6 +66,8 @@ class ToolRegistry:
|
||||
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
|
||||
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
|
||||
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
|
||||
# Agent dir for re-loading registry MCP after credential resync.
|
||||
self._mcp_registry_agent_path: Path | None = None
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -490,7 +492,13 @@ class ToolRegistry:
|
||||
self._resolve_mcp_server_config(server_config, base_dir)
|
||||
for server_config in server_list
|
||||
]
|
||||
self.load_registry_servers(resolved_server_list, log_summary=False)
|
||||
# Ordered first-wins for duplicate tool names across servers; keep tools.py tools.
|
||||
self.load_registry_servers(
|
||||
resolved_server_list,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=False,
|
||||
)
|
||||
|
||||
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
@@ -499,6 +507,10 @@ class ToolRegistry:
|
||||
def _register_mcp_server_with_retry(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> tuple[bool, int, str | None]:
|
||||
"""Register a single MCP server with one retry for transient failures."""
|
||||
name = server_config.get("name", "unknown")
|
||||
@@ -506,7 +518,12 @@ class ToolRegistry:
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
count = self.register_mcp_server(server_config)
|
||||
count = self.register_mcp_server(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=tool_cap,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
if count > 0:
|
||||
return True, count, None
|
||||
last_error = "registered 0 tools"
|
||||
@@ -532,13 +549,38 @@ class ToolRegistry:
|
||||
server_list: list[dict[str, Any]],
|
||||
*,
|
||||
log_summary: bool = True,
|
||||
preserve_existing_tools: bool = True,
|
||||
max_tools: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Register resolved registry-selected MCP servers with retry and status tracking."""
|
||||
"""Register MCP servers from a resolved config list (registry and/or static).
|
||||
|
||||
``preserve_existing_tools`` enforces first-wins tool names (FR-100): later
|
||||
servers skip names already taken— including tools from ``mcp_servers.json``
|
||||
or ``tools.py`` when those were loaded first.
|
||||
|
||||
``max_tools`` caps how many *new* tool names are registered across this batch
|
||||
(collisions do not consume the cap). When ``log_collisions`` is True, skipped
|
||||
duplicate names emit a warning (FR-101).
|
||||
"""
|
||||
results: list[dict[str, Any]] = []
|
||||
tools_added_batch = 0
|
||||
|
||||
for server_config in server_list:
|
||||
remaining: int | None = None
|
||||
if max_tools is not None:
|
||||
remaining = max_tools - tools_added_batch
|
||||
if remaining <= 0:
|
||||
break
|
||||
|
||||
name = server_config.get("name", "unknown")
|
||||
success, tools_loaded, error = self._register_mcp_server_with_retry(server_config)
|
||||
success, tools_loaded, error = self._register_mcp_server_with_retry(
|
||||
server_config,
|
||||
preserve_existing_tools=preserve_existing_tools,
|
||||
tool_cap=remaining,
|
||||
log_collisions=log_collisions,
|
||||
)
|
||||
tools_added_batch += tools_loaded
|
||||
result = {
|
||||
"server": name,
|
||||
"status": "loaded" if success else "skipped",
|
||||
@@ -565,6 +607,10 @@ class ToolRegistry:
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
use_connection_manager: bool = True,
|
||||
*,
|
||||
preserve_existing_tools: bool = True,
|
||||
tool_cap: int | None = None,
|
||||
log_collisions: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Register an MCP server and discover its tools.
|
||||
@@ -581,6 +627,9 @@ class ToolRegistry:
|
||||
- headers: HTTP headers (for http)
|
||||
- description: Server description (optional)
|
||||
use_connection_manager: When True, reuse a shared client keyed by server name
|
||||
preserve_existing_tools: If True, do not replace tools already in the registry.
|
||||
tool_cap: Max tools to newly register from this server (None = unlimited).
|
||||
log_collisions: If True, log when this server skips a tool name already taken.
|
||||
|
||||
Returns:
|
||||
Number of tools registered from this server
|
||||
@@ -623,6 +672,23 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name] = set()
|
||||
count = 0
|
||||
for mcp_tool in client.list_tools():
|
||||
if tool_cap is not None and count >= tool_cap:
|
||||
break
|
||||
|
||||
if preserve_existing_tools and mcp_tool.name in self._tools:
|
||||
if log_collisions:
|
||||
origin_server = (
|
||||
self._find_mcp_origin_server_for_tool(mcp_tool.name) or "<existing>"
|
||||
)
|
||||
logger.warning(
|
||||
"MCP tool '%s' from '%s' shadowed by '%s' (loaded first)",
|
||||
mcp_tool.name,
|
||||
server_name,
|
||||
origin_server,
|
||||
)
|
||||
# Skip registration; do not update MCP tool bookkeeping for this server.
|
||||
continue
|
||||
|
||||
# Convert MCP tool to framework Tool (strips context params from LLM schema)
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
@@ -688,11 +754,27 @@ class ToolRegistry:
|
||||
self._mcp_server_tools[server_name].add(mcp_tool.name)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Registered {count} tools from MCP server '{config.name}'")
|
||||
logger.info(
|
||||
"MCP Registry Load",
|
||||
extra={
|
||||
"server": config.name,
|
||||
"status": "success",
|
||||
"tools_loaded": count,
|
||||
"skipped_reason": None,
|
||||
},
|
||||
)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register MCP server: {e}")
|
||||
logger.error(
|
||||
"MCP Registry Load",
|
||||
extra={
|
||||
"server": server_config.get("name", "unknown"),
|
||||
"status": "failed",
|
||||
"tools_loaded": 0,
|
||||
"skipped_reason": str(e),
|
||||
},
|
||||
)
|
||||
if "Connection closed" in str(e) and os.name == "nt":
|
||||
logger.debug(
|
||||
"On Windows, check that the MCP subprocess starts (e.g. uv in PATH, "
|
||||
@@ -700,6 +782,12 @@ class ToolRegistry:
|
||||
)
|
||||
return 0
|
||||
|
||||
def _find_mcp_origin_server_for_tool(self, tool_name: str) -> str | None:
|
||||
for server_name, tool_names in self._mcp_server_tools.items():
|
||||
if tool_name in tool_names:
|
||||
return server_name
|
||||
return None
|
||||
|
||||
def _convert_mcp_tool_to_framework_tool(self, mcp_tool: Any) -> Tool:
|
||||
"""
|
||||
Convert an MCP tool to a framework Tool.
|
||||
@@ -787,6 +875,37 @@ class ToolRegistry:
|
||||
# MCP credential resync
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_mcp_registry_agent_path(self, agent_path: Path | None) -> None:
|
||||
"""Remember agent dir so registry MCP servers reload after credential resync."""
|
||||
self._mcp_registry_agent_path = None if agent_path is None else Path(agent_path)
|
||||
|
||||
def reload_registry_mcp_servers_after_resync(self) -> None:
|
||||
"""Re-run ``mcp_registry.json`` resolution and register servers (post-resync)."""
|
||||
if self._mcp_registry_agent_path is None:
|
||||
return
|
||||
from framework.runner.mcp_registry import MCPRegistry
|
||||
|
||||
try:
|
||||
reg = MCPRegistry()
|
||||
reg.initialize()
|
||||
configs, selection_max_tools = reg.load_agent_selection(self._mcp_registry_agent_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to reload MCP registry servers after resync for '%s': %s",
|
||||
self._mcp_registry_agent_path.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
if not configs:
|
||||
return
|
||||
self.load_registry_servers(
|
||||
configs,
|
||||
log_summary=True,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
|
||||
def _snapshot_credentials(self) -> set[str]:
|
||||
"""Return the set of credential filenames currently on disk."""
|
||||
try:
|
||||
@@ -832,9 +951,12 @@ class ToolRegistry:
|
||||
for name in self._mcp_tool_names:
|
||||
self._tools.pop(name, None)
|
||||
self._mcp_tool_names.clear()
|
||||
self._mcp_server_tools.clear()
|
||||
|
||||
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
|
||||
self.load_mcp_config(self._mcp_config_path)
|
||||
if self._mcp_registry_agent_path is not None:
|
||||
self.reload_registry_mcp_servers_after_resync()
|
||||
|
||||
logger.info("MCP server resync complete")
|
||||
return True
|
||||
|
||||
@@ -200,6 +200,8 @@ class AgentRuntime:
|
||||
self._skills_manager.load()
|
||||
|
||||
self.skill_dirs: list[str] = self._skills_manager.allowlisted_dirs
|
||||
self.context_warn_ratio: float | None = self._skills_manager.context_warn_ratio
|
||||
self.batch_init_nudge: str | None = self._skills_manager.batch_init_nudge
|
||||
|
||||
# Primary graph identity
|
||||
self._graph_id: str = graph_id or "primary"
|
||||
@@ -348,6 +350,8 @@ class AgentRuntime:
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
context_warn_ratio=self.context_warn_ratio,
|
||||
batch_init_nudge=self.batch_init_nudge,
|
||||
)
|
||||
await stream.start()
|
||||
self._streams[ep_id] = stream
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Any
|
||||
from framework.observability import set_trace_context
|
||||
from framework.schemas.decision import Decision, DecisionType, Option, Outcome
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,7 +62,7 @@ class Runtime:
|
||||
logger.warning(f"Storage path does not exist, creating: {path}")
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.storage = FileStorage(storage_path)
|
||||
self.storage = ConcurrentStorage(storage_path)
|
||||
self._current_run: Run | None = None
|
||||
self._current_node: str = "unknown"
|
||||
|
||||
@@ -132,8 +132,8 @@ class Runtime:
|
||||
self._current_run.output_data = output_data or {}
|
||||
self._current_run.complete(status, narrative)
|
||||
|
||||
# Save to storage
|
||||
self.storage.save_run(self._current_run)
|
||||
# Save to storage (sync — Runtime methods are not async)
|
||||
self.storage.save_run_sync(self._current_run)
|
||||
self._current_run = None
|
||||
|
||||
def set_node(self, node_id: str) -> None:
|
||||
|
||||
@@ -189,6 +189,8 @@ class ExecutionStream:
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
context_warn_ratio: float | None = None,
|
||||
batch_init_nudge: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize execution stream.
|
||||
@@ -215,6 +217,8 @@ class ExecutionStream:
|
||||
skills_catalog_prompt: Available skills catalog for system prompt
|
||||
protocols_prompt: Default skill operational protocols for system prompt
|
||||
skill_dirs: Skill base directories for Tier 3 resource access
|
||||
context_warn_ratio: Token usage ratio to trigger DS-13 preservation warning
|
||||
batch_init_nudge: System prompt nudge for DS-12 batch auto-detection
|
||||
"""
|
||||
self.stream_id = stream_id
|
||||
self.entry_spec = entry_spec
|
||||
@@ -239,6 +243,8 @@ class ExecutionStream:
|
||||
self._skills_catalog_prompt = skills_catalog_prompt
|
||||
self._protocols_prompt = protocols_prompt
|
||||
self._skill_dirs: list[str] = skill_dirs or []
|
||||
self._context_warn_ratio: float | None = context_warn_ratio
|
||||
self._batch_init_nudge: str | None = batch_init_nudge
|
||||
|
||||
_es_logger = logging.getLogger(__name__)
|
||||
if protocols_prompt:
|
||||
@@ -703,6 +709,8 @@ class ExecutionStream:
|
||||
skills_catalog_prompt=self._skills_catalog_prompt,
|
||||
protocols_prompt=self._protocols_prompt,
|
||||
skill_dirs=self._skill_dirs,
|
||||
context_warn_ratio=self._context_warn_ratio,
|
||||
batch_init_nudge=self._batch_init_nudge,
|
||||
)
|
||||
# Track executor so inject_input() can reach EventLoopNode instances
|
||||
self._active_executors[execution_id] = executor
|
||||
@@ -961,7 +969,10 @@ class ExecutionStream:
|
||||
return
|
||||
import json as _json
|
||||
|
||||
session_dir = self._session_store.get_session_path(execution_id)
|
||||
try:
|
||||
session_dir = self._session_store.get_session_path(execution_id)
|
||||
except ValueError:
|
||||
return
|
||||
runs_file = session_dir / "runs.jsonl"
|
||||
now = datetime.now()
|
||||
record = {
|
||||
|
||||
@@ -90,9 +90,16 @@ async def create_queen(
|
||||
try:
|
||||
registry = MCPRegistry()
|
||||
registry.initialize()
|
||||
registry_configs = registry.load_agent_selection(queen_pkg_dir)
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
|
||||
if registry_configs:
|
||||
results = queen_registry.load_registry_servers(registry_configs)
|
||||
results = queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
logger.info("Queen: loaded MCP registry servers: %s", results)
|
||||
except Exception:
|
||||
logger.warning("Queen: MCP registry config failed to load", exc_info=True)
|
||||
@@ -232,6 +239,7 @@ async def create_queen(
|
||||
)
|
||||
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
_queen_skill_dirs: list[str] = []
|
||||
try:
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
|
||||
@@ -242,6 +250,7 @@ async def create_queen(
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
|
||||
_queen_skill_dirs = _queen_skills_mgr.allowlisted_dirs
|
||||
except Exception:
|
||||
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
|
||||
|
||||
@@ -306,6 +315,7 @@ async def create_queen(
|
||||
dynamic_tools_provider=phase_state.get_current_tools,
|
||||
dynamic_prompt_provider=phase_state.get_current_prompt,
|
||||
iteration_metadata_provider=lambda: {"phase": phase_state.phase},
|
||||
skill_dirs=_queen_skill_dirs,
|
||||
)
|
||||
session.queen_executor = executor
|
||||
|
||||
|
||||
@@ -9,27 +9,42 @@ from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.config import DefaultSkillConfig, SkillsConfig
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.installer import (
|
||||
fork_skill,
|
||||
install_from_git,
|
||||
install_from_registry,
|
||||
remove_skill,
|
||||
)
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.models import TrustStatus
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
from framework.skills.registry import RegistryClient
|
||||
from framework.skills.skill_errors import SkillError, SkillErrorCode, log_skill_error
|
||||
from framework.skills.trust import TrustedRepoStore, TrustGate
|
||||
from framework.skills.validator import ValidationResult, validate_strict
|
||||
|
||||
__all__ = [
|
||||
"DefaultSkillConfig",
|
||||
"DefaultSkillManager",
|
||||
"DiscoveryConfig",
|
||||
"ParsedSkill",
|
||||
"RegistryClient",
|
||||
"SkillCatalog",
|
||||
"SkillDiscovery",
|
||||
"SkillError",
|
||||
"SkillErrorCode",
|
||||
"SkillsConfig",
|
||||
"SkillsManager",
|
||||
"SkillsManagerConfig",
|
||||
"TrustGate",
|
||||
"TrustedRepoStore",
|
||||
"TrustStatus",
|
||||
"parse_skill_md",
|
||||
"SkillError",
|
||||
"SkillErrorCode",
|
||||
"ValidationResult",
|
||||
"fork_skill",
|
||||
"install_from_git",
|
||||
"install_from_registry",
|
||||
"log_skill_error",
|
||||
"parse_skill_md",
|
||||
"remove_skill",
|
||||
"validate_strict",
|
||||
]
|
||||
|
||||
@@ -20,3 +20,5 @@ What to extract: URLs and key snippets (not full pages), relevant API fields
|
||||
|
||||
Before transitioning to the next phase/node, write a handoff summary to
|
||||
`_handoff_context` with everything the next phase needs to know.
|
||||
|
||||
You will receive an alert when context reaches {{warn_at_usage_ratio_pct}}% — preserve immediately.
|
||||
|
||||
@@ -14,5 +14,5 @@ When a tool call fails:
|
||||
2. Decide — transient: retry once. Structural fixable: fix and retry.
|
||||
Structural unfixable: record as failed, move to next item.
|
||||
Blocking all progress: record escalation note.
|
||||
3. Adapt — if same tool failed 3+ times, stop using it and find alternative.
|
||||
3. Adapt — if same tool failed {{max_retries_per_tool}}+ times, stop using it and find alternative.
|
||||
Update plan in notes. Never silently drop the failed item.
|
||||
|
||||
@@ -8,7 +8,7 @@ metadata:
|
||||
|
||||
## Operational Protocol: Quality Self-Assessment
|
||||
|
||||
Every 5 iterations, self-assess:
|
||||
Every {{assessment_interval}} iterations, self-assess:
|
||||
|
||||
1. On-task? Still working toward the stated objective?
|
||||
2. Thorough? Cutting corners compared to earlier?
|
||||
|
||||
+1352
-6
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
@@ -18,6 +19,56 @@ logger = logging.getLogger(__name__)
|
||||
# Default skills directory relative to this module
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
|
||||
# Default config values per skill — used for {{placeholder}} substitution
|
||||
_SKILL_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"hive.quality-monitor": {"assessment_interval": 5},
|
||||
"hive.error-recovery": {"max_retries_per_tool": 3},
|
||||
"hive.context-preservation": {"warn_at_usage_ratio_pct": 45},
|
||||
"hive.batch-ledger": {"checkpoint_every_n": 5},
|
||||
}
|
||||
|
||||
# Keywords that indicate a batch processing scenario (DS-12)
|
||||
_BATCH_KEYWORDS: tuple[str, ...] = (
|
||||
"list of",
|
||||
"collection of",
|
||||
"set of",
|
||||
"batch of",
|
||||
"each item",
|
||||
"for each",
|
||||
"process all",
|
||||
"records",
|
||||
"entries",
|
||||
"rows",
|
||||
"items",
|
||||
)
|
||||
|
||||
_BATCH_INIT_NUDGE = (
|
||||
"Note: your input appears to describe a batch operation. "
|
||||
"Initialize `_batch_ledger` with the total item count before processing."
|
||||
)
|
||||
|
||||
|
||||
def is_batch_scenario(text: str) -> bool:
|
||||
"""Return True if *text* contains batch-processing indicators (DS-12)."""
|
||||
lower = text.lower()
|
||||
return any(kw in lower for kw in _BATCH_KEYWORDS)
|
||||
|
||||
|
||||
def _apply_overrides(skill_name: str, body: str, overrides: dict[str, Any]) -> str:
|
||||
"""Substitute {{placeholder}} values in a skill body using overrides + defaults."""
|
||||
defaults = _SKILL_DEFAULTS.get(skill_name, {})
|
||||
# Convert float warn_at_usage_ratio → warn_at_usage_ratio_pct for the placeholder
|
||||
if "warn_at_usage_ratio" in overrides:
|
||||
overrides = dict(overrides)
|
||||
overrides.setdefault(
|
||||
"warn_at_usage_ratio_pct", int(float(overrides["warn_at_usage_ratio"]) * 100)
|
||||
)
|
||||
values = {**defaults, **overrides}
|
||||
for key, val in values.items():
|
||||
body = body.replace(f"{{{{{key}}}}}", str(val))
|
||||
return body
|
||||
|
||||
|
||||
# Ordered list of default skills (name → directory)
|
||||
SKILL_REGISTRY: dict[str, str] = {
|
||||
"hive.note-taking": "note-taking",
|
||||
@@ -123,8 +174,10 @@ class DefaultSkillManager:
|
||||
skill = self._skills.get(skill_name)
|
||||
if skill is None:
|
||||
continue
|
||||
# Use the full body — each SKILL.md contains exactly one protocol section
|
||||
parts.append(skill.body)
|
||||
# Apply config overrides to {{placeholder}} values before injection
|
||||
overrides = self._config.get_default_overrides(skill_name)
|
||||
body = _apply_overrides(skill_name, skill.body, overrides)
|
||||
parts.append(body)
|
||||
|
||||
if len(parts) <= 1:
|
||||
return ""
|
||||
@@ -198,3 +251,28 @@ class DefaultSkillManager:
|
||||
def active_skills(self) -> dict[str, ParsedSkill]:
|
||||
"""All active default skills keyed by name."""
|
||||
return dict(self._skills)
|
||||
|
||||
@property
|
||||
def batch_init_nudge(self) -> str | None:
|
||||
"""Nudge text to prepend to system prompt when batch input detected (DS-12).
|
||||
|
||||
Returns None if ``hive.batch-ledger`` is disabled or auto_detect_batch is False.
|
||||
"""
|
||||
if "hive.batch-ledger" not in self._skills:
|
||||
return None
|
||||
overrides = self._config.get_default_overrides("hive.batch-ledger")
|
||||
if overrides.get("auto_detect_batch") is False:
|
||||
return None
|
||||
return _BATCH_INIT_NUDGE
|
||||
|
||||
@property
|
||||
def context_warn_ratio(self) -> float | None:
|
||||
"""Token usage ratio at which to inject a context preservation warning (DS-13).
|
||||
|
||||
Returns None if ``hive.context-preservation`` is disabled.
|
||||
Defaults to 0.45 when the skill is active but no override is set.
|
||||
"""
|
||||
if "hive.context-preservation" not in self._skills:
|
||||
return None
|
||||
overrides = self._config.get_default_overrides("hive.context-preservation")
|
||||
return float(overrides.get("warn_at_usage_ratio", 0.45))
|
||||
|
||||
@@ -0,0 +1,348 @@
|
||||
"""Skill install, remove, and fork operations.
|
||||
|
||||
Handles filesystem operations for the hive skill CLI:
|
||||
- install_from_git: git clone --depth=1 → copy to target directory
|
||||
- install_from_registry: resolve registry entry → delegate to install_from_git
|
||||
- remove_skill: delete a skill from ~/.hive/skills/
|
||||
- fork_skill: copy a skill to a new location with a new name
|
||||
- maybe_show_install_notice: one-time security notice on first install (NFR-5)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.skill_errors import SkillError, SkillErrorCode
|
||||
|
||||
# Default install destination for user-scope skills
|
||||
USER_SKILLS_DIR = Path.home() / ".hive" / "skills"
|
||||
|
||||
# Sentinel file for the one-time security notice on first install (NFR-5)
|
||||
INSTALL_NOTICE_SENTINEL = Path.home() / ".hive" / ".install_notice_shown"
|
||||
|
||||
_INSTALL_NOTICE = """\
|
||||
─────────────────────────────────────────────────────────────
|
||||
Security Notice: Installing Third-Party Skills
|
||||
─────────────────────────────────────────────────────────────
|
||||
Skills are instructions executed by AI agents. A malicious
|
||||
skill can manipulate agent behavior, exfiltrate data, or
|
||||
cause unintended actions.
|
||||
|
||||
Only install skills from sources you trust. Review the
|
||||
SKILL.md before running it in a production environment.
|
||||
|
||||
This notice is shown once. Use 'hive skill doctor' to audit
|
||||
installed skills at any time.
|
||||
─────────────────────────────────────────────────────────────
|
||||
"""
|
||||
|
||||
|
||||
def maybe_show_install_notice() -> None:
|
||||
"""Print a one-time security notice before the first skill install (NFR-5).
|
||||
|
||||
Touches a sentinel file in ~/.hive/ after showing the notice so it is
|
||||
only displayed once across all future installs.
|
||||
"""
|
||||
if INSTALL_NOTICE_SENTINEL.exists():
|
||||
return
|
||||
print(_INSTALL_NOTICE, flush=True)
|
||||
try:
|
||||
INSTALL_NOTICE_SENTINEL.parent.mkdir(parents=True, exist_ok=True)
|
||||
INSTALL_NOTICE_SENTINEL.touch()
|
||||
except OSError:
|
||||
pass # If we can't write the sentinel, just show the notice every time
|
||||
|
||||
|
||||
def install_from_git(
|
||||
git_url: str,
|
||||
skill_name: str,
|
||||
subdirectory: str | None = None,
|
||||
version: str | None = None,
|
||||
target_dir: Path | None = None,
|
||||
) -> Path:
|
||||
"""Install a skill from a git repository.
|
||||
|
||||
Clones the repository with --depth=1 into a temporary directory, then
|
||||
copies the skill subdirectory (or repo root) to the target location.
|
||||
|
||||
Args:
|
||||
git_url: Git repository URL to clone.
|
||||
skill_name: Name of the skill — used as the install directory name.
|
||||
subdirectory: Relative path within the repo to the skill directory.
|
||||
If None, the repo root is treated as the skill directory.
|
||||
version: Git ref to checkout (tag, branch, or commit). Defaults to
|
||||
the remote's default branch.
|
||||
target_dir: Where to install the skill. Defaults to
|
||||
~/.hive/skills/<skill_name>/.
|
||||
|
||||
Returns:
|
||||
Path to the installed skill directory (the parent of SKILL.md).
|
||||
|
||||
Raises:
|
||||
SkillError: On any failure (git not found, clone failed, SKILL.md missing).
|
||||
"""
|
||||
if shutil.which("git") is None:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot install '{skill_name}' from {git_url}",
|
||||
why="git is not installed or not on PATH.",
|
||||
fix="Install git (https://git-scm.com/) and retry.",
|
||||
)
|
||||
|
||||
dest = (target_dir or USER_SKILLS_DIR) / skill_name
|
||||
if dest.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot install '{skill_name}'",
|
||||
why=f"Directory already exists: {dest}",
|
||||
fix=f"Run 'hive skill remove {skill_name}' first, or use a different --name.",
|
||||
)
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix="hive-skill-install-")
|
||||
try:
|
||||
_git_clone_shallow(git_url, Path(tmp_dir), version=version)
|
||||
|
||||
# Locate the skill within the cloned repo
|
||||
source_dir = Path(tmp_dir) / subdirectory if subdirectory else Path(tmp_dir)
|
||||
skill_md = source_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what=f"No SKILL.md found in '{subdirectory or '/'}' of {git_url}",
|
||||
why="The expected SKILL.md file is not present at the given path.",
|
||||
fix=(
|
||||
"Check the repository structure and use "
|
||||
"'hive skill install --from <url>' with the correct subdirectory."
|
||||
),
|
||||
)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_copy_skill_dir(source_dir, dest)
|
||||
return dest
|
||||
|
||||
except SkillError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to install '{skill_name}' from {git_url}",
|
||||
why=str(exc),
|
||||
fix="Check the URL, your network connection, and git configuration.",
|
||||
) from exc
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def install_from_registry(
|
||||
registry_entry: dict,
|
||||
target_dir: Path | None = None,
|
||||
version: str | None = None,
|
||||
) -> Path:
|
||||
"""Install a skill using a registry index entry.
|
||||
|
||||
Resolves the git_url and subdirectory from the registry entry and
|
||||
delegates to install_from_git.
|
||||
|
||||
Args:
|
||||
registry_entry: A skill entry dict from skill_index.json.
|
||||
target_dir: Override install destination.
|
||||
version: Override version (defaults to entry's 'version' field).
|
||||
|
||||
Returns:
|
||||
Path to the installed skill directory.
|
||||
|
||||
Raises:
|
||||
SkillError: If the registry entry is missing required fields or install fails.
|
||||
"""
|
||||
name = registry_entry.get("name")
|
||||
git_url = registry_entry.get("git_url")
|
||||
|
||||
if not name or not git_url:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_NOT_FOUND,
|
||||
what="Incomplete registry entry — missing 'name' or 'git_url'.",
|
||||
why="The registry index entry does not contain all required fields.",
|
||||
fix="Report this issue to the registry maintainer.",
|
||||
)
|
||||
|
||||
resolved_version = version or registry_entry.get("version")
|
||||
subdirectory = registry_entry.get("subdirectory")
|
||||
|
||||
return install_from_git(
|
||||
git_url=git_url,
|
||||
skill_name=str(name),
|
||||
subdirectory=subdirectory,
|
||||
version=resolved_version,
|
||||
target_dir=target_dir,
|
||||
)
|
||||
|
||||
|
||||
def remove_skill(name: str, skills_dir: Path | None = None) -> bool:
|
||||
"""Remove an installed skill from the user skills directory.
|
||||
|
||||
Args:
|
||||
name: Skill directory name to remove.
|
||||
skills_dir: Override the search directory (default: ~/.hive/skills/).
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found.
|
||||
|
||||
Raises:
|
||||
SkillError: If the directory exists but cannot be removed.
|
||||
"""
|
||||
target = (skills_dir or USER_SKILLS_DIR) / name
|
||||
if not target.exists():
|
||||
return False
|
||||
try:
|
||||
shutil.rmtree(target)
|
||||
return True
|
||||
except OSError as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to remove skill '{name}' at {target}",
|
||||
why=str(exc),
|
||||
fix="Check file permissions and try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
def fork_skill(
|
||||
source: ParsedSkill,
|
||||
new_name: str,
|
||||
target_dir: Path,
|
||||
) -> Path:
|
||||
"""Create a local editable copy of a skill with a new name.
|
||||
|
||||
Copies the skill's base directory to target_dir/new_name/ and rewrites
|
||||
the 'name' field in the copied SKILL.md frontmatter.
|
||||
|
||||
Args:
|
||||
source: The source skill to fork (from SkillDiscovery).
|
||||
new_name: Name for the forked skill.
|
||||
target_dir: Parent directory for the fork (e.g. ~/.hive/skills/).
|
||||
|
||||
Returns:
|
||||
Path to the forked skill directory.
|
||||
|
||||
Raises:
|
||||
SkillError: If the target already exists or the copy fails.
|
||||
"""
|
||||
dest = target_dir / new_name
|
||||
if dest.exists():
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot fork to '{dest}'",
|
||||
why="Target directory already exists.",
|
||||
fix=f"Choose a different --name or remove '{dest}' first.",
|
||||
)
|
||||
|
||||
source_dir = Path(source.base_dir)
|
||||
try:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_copy_skill_dir(source_dir, dest)
|
||||
except OSError as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Failed to fork skill '{source.name}' to '{dest}'",
|
||||
why=str(exc),
|
||||
fix="Check file permissions and available disk space.",
|
||||
) from exc
|
||||
|
||||
# Rewrite the name in the forked SKILL.md via YAML round-trip (safe)
|
||||
forked_skill_md = dest / "SKILL.md"
|
||||
if forked_skill_md.exists():
|
||||
_rewrite_name_in_skill_md(forked_skill_md, new_name)
|
||||
|
||||
return dest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _git_clone_shallow(git_url: str, target: Path, version: str | None = None) -> None:
|
||||
"""Clone a git repo at --depth=1 into target directory.
|
||||
|
||||
Args:
|
||||
git_url: Repository URL.
|
||||
target: Destination directory (will be created by git).
|
||||
version: Optional git ref (branch/tag) to clone.
|
||||
|
||||
Raises:
|
||||
SkillError: If the clone fails.
|
||||
"""
|
||||
cmd = ["git", "clone", "--depth=1"]
|
||||
if version:
|
||||
cmd += ["--branch", version]
|
||||
cmd += [git_url, str(target)]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"git clone timed out for {git_url}",
|
||||
why="The clone operation took longer than 60 seconds.",
|
||||
fix="Check your network connection and retry.",
|
||||
) from None
|
||||
except (FileNotFoundError, OSError) as exc:
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"Cannot run git for {git_url}",
|
||||
why=str(exc),
|
||||
fix="Ensure git is installed and on PATH.",
|
||||
) from exc
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()
|
||||
raise SkillError(
|
||||
code=SkillErrorCode.SKILL_ACTIVATION_FAILED,
|
||||
what=f"git clone failed for {git_url}",
|
||||
why=stderr or f"git exited with code {result.returncode}",
|
||||
fix="Check the URL is correct and the repository is publicly accessible.",
|
||||
)
|
||||
|
||||
|
||||
def _copy_skill_dir(src: Path, dst: Path) -> None:
|
||||
"""Copy a skill directory, ignoring VCS and cache artifacts."""
|
||||
ignore = shutil.ignore_patterns(".git", "__pycache__", "*.pyc", ".venv", "venv", "node_modules")
|
||||
shutil.copytree(src, dst, ignore=ignore)
|
||||
|
||||
|
||||
def _rewrite_name_in_skill_md(skill_md: Path, new_name: str) -> None:
|
||||
"""Rewrite the 'name' field in a SKILL.md frontmatter via YAML round-trip.
|
||||
|
||||
Parses the frontmatter with yaml.safe_load, updates 'name', re-serializes
|
||||
with yaml.dump, and reconstructs the file as:
|
||||
---
|
||||
<yaml>
|
||||
---
|
||||
<body>
|
||||
|
||||
Falls back to no-op if the file can't be parsed (the copy is still usable).
|
||||
"""
|
||||
import yaml
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
return
|
||||
frontmatter = yaml.safe_load(parts[1].strip())
|
||||
if not isinstance(frontmatter, dict):
|
||||
return
|
||||
frontmatter["name"] = new_name
|
||||
new_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True)
|
||||
new_content = f"---\n{new_yaml}---\n{parts[2]}"
|
||||
skill_md.write_text(new_content, encoding="utf-8")
|
||||
except Exception:
|
||||
pass # Degraded: forked copy works, name just isn't updated
|
||||
@@ -67,6 +67,7 @@ class SkillsManager:
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
self._default_mgr: object = None # DefaultSkillManager, set after load()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
@@ -90,6 +91,7 @@ class SkillsManager:
|
||||
mgr._catalog_prompt = skills_catalog_prompt
|
||||
mgr._protocols_prompt = protocols_prompt
|
||||
mgr._allowlisted_dirs = []
|
||||
mgr._default_mgr = None
|
||||
return mgr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -146,6 +148,7 @@ class SkillsManager:
|
||||
default_mgr.load()
|
||||
default_mgr.log_active_skills()
|
||||
protocols_prompt = default_mgr.build_protocols_prompt()
|
||||
self._default_mgr = default_mgr
|
||||
# DX-3: Community skill startup summary
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
community_count = len(catalog._skills) if catalog_prompt else 0
|
||||
@@ -189,6 +192,20 @@ class SkillsManager:
|
||||
"""Skill base directories for Tier 3 resource access (AS-6)."""
|
||||
return self._allowlisted_dirs
|
||||
|
||||
@property
|
||||
def batch_init_nudge(self) -> str | None:
|
||||
"""Batch init nudge text for DS-12 auto-detection, or None if disabled."""
|
||||
if self._default_mgr is None:
|
||||
return None
|
||||
return self._default_mgr.batch_init_nudge # type: ignore[union-attr]
|
||||
|
||||
@property
|
||||
def context_warn_ratio(self) -> float | None:
|
||||
"""Token usage ratio for DS-13 context preservation warning, or None if disabled."""
|
||||
if self._default_mgr is None:
|
||||
return None
|
||||
return self._default_mgr.context_warn_ratio # type: ignore[union-attr]
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
|
||||
@@ -211,6 +211,15 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
fix=f"Rename the directory to '{name}' or set name to '{parent_dir_name}'.",
|
||||
)
|
||||
|
||||
# Coerce compatibility / allowed-tools to list[str] — many SKILL.md files
|
||||
# in the wild use a plain string instead of a YAML list.
|
||||
raw_compat = frontmatter.get("compatibility")
|
||||
if isinstance(raw_compat, str):
|
||||
raw_compat = [raw_compat]
|
||||
raw_tools = frontmatter.get("allowed-tools")
|
||||
if isinstance(raw_tools, str):
|
||||
raw_tools = [raw_tools]
|
||||
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description=str(description).strip(),
|
||||
@@ -219,7 +228,7 @@ def parse_skill_md(path: Path, source_scope: str = "project") -> ParsedSkill | N
|
||||
source_scope=source_scope,
|
||||
body=body,
|
||||
license=frontmatter.get("license"),
|
||||
compatibility=frontmatter.get("compatibility"),
|
||||
compatibility=raw_compat,
|
||||
metadata=frontmatter.get("metadata"),
|
||||
allowed_tools=frontmatter.get("allowed-tools"),
|
||||
allowed_tools=raw_tools,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Registry client for the Hive community skill registry.
|
||||
|
||||
Fetches the skill index from the hive-skill-registry GitHub repo, caches it
|
||||
locally, and provides search and resolution utilities.
|
||||
|
||||
The registry repo (Phase 3) may not exist yet. All public methods degrade
|
||||
gracefully — returning None or [] on any network or parse failure.
|
||||
|
||||
Configure a custom registry URL via the HIVE_REGISTRY_URL environment variable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
from urllib.request import urlopen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default registry index URL (Phase 3 repo, may not exist yet)
|
||||
_DEFAULT_REGISTRY_URL = (
|
||||
"https://raw.githubusercontent.com/hive-skill-registry/"
|
||||
"hive-skill-registry/main/skill_index.json"
|
||||
)
|
||||
|
||||
_CACHE_DIR = Path.home() / ".hive" / "registry_cache"
|
||||
_CACHE_INDEX_PATH = _CACHE_DIR / "skill_index.json"
|
||||
_CACHE_METADATA_PATH = _CACHE_DIR / "metadata.json"
|
||||
_CACHE_TTL_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
class RegistryClient:
|
||||
"""Client for the Hive community skill registry.
|
||||
|
||||
All public methods return None / [] on any failure — never raise.
|
||||
Network errors, parse failures, and missing registries are all
|
||||
treated as graceful degradation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry_url: str | None = None,
|
||||
cache_dir: Path | None = None,
|
||||
) -> None:
|
||||
self._url = registry_url or os.environ.get("HIVE_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
|
||||
cache_root = cache_dir or _CACHE_DIR
|
||||
self._index_path = cache_root / "skill_index.json"
|
||||
self._metadata_path = cache_root / "metadata.json"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def fetch_index(self, force_refresh: bool = False) -> dict | None:
|
||||
"""Return the registry index dict.
|
||||
|
||||
Uses the local cache if it is fresh (within TTL) unless
|
||||
force_refresh=True. Returns None on any failure.
|
||||
"""
|
||||
if not force_refresh and self._is_cache_fresh():
|
||||
cached = self._load_cache()
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
raw = self._http_fetch(self._url)
|
||||
if raw is None:
|
||||
# Network unavailable — fall back to stale cache if present
|
||||
stale = self._load_cache()
|
||||
if stale is not None:
|
||||
logger.debug("registry: network unavailable, using stale cache")
|
||||
return stale
|
||||
|
||||
try:
|
||||
data = json.loads(raw.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
logger.warning("registry: failed to parse index JSON: %s", exc)
|
||||
return self._load_cache()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("registry: index is not a JSON object")
|
||||
return self._load_cache()
|
||||
|
||||
self._save_cache(data)
|
||||
return data
|
||||
|
||||
def search(self, query: str) -> list[dict]:
|
||||
"""Search registry skills by name, description, or tags.
|
||||
|
||||
Case-insensitive substring match. Returns [] if index unavailable.
|
||||
"""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return []
|
||||
skills = index.get("skills", [])
|
||||
if not isinstance(skills, list):
|
||||
return []
|
||||
q = query.lower()
|
||||
results = []
|
||||
for entry in skills:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = str(entry.get("name", "")).lower()
|
||||
description = str(entry.get("description", "")).lower()
|
||||
tags = " ".join(str(t) for t in entry.get("tags", [])).lower()
|
||||
if q in name or q in description or q in tags:
|
||||
results.append(entry)
|
||||
return results
|
||||
|
||||
def get_skill_entry(self, name: str) -> dict | None:
|
||||
"""Look up a single skill by exact name. Returns None if not found."""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return None
|
||||
for entry in index.get("skills", []):
|
||||
if isinstance(entry, dict) and entry.get("name") == name:
|
||||
return entry
|
||||
return None
|
||||
|
||||
def get_pack(self, pack_name: str) -> list[str] | None:
|
||||
"""Return the list of skill names in a starter pack.
|
||||
|
||||
Returns None if the pack is not found or the index is unavailable.
|
||||
"""
|
||||
index = self.fetch_index()
|
||||
if not index:
|
||||
return None
|
||||
for pack in index.get("packs", []):
|
||||
if isinstance(pack, dict) and pack.get("name") == pack_name:
|
||||
skills = pack.get("skills", [])
|
||||
if isinstance(skills, list):
|
||||
return [s for s in skills if isinstance(s, str)]
|
||||
return None
|
||||
|
||||
def resolve_git_url(self, name: str) -> tuple[str, str | None] | None:
|
||||
"""Return (git_url, subdirectory) for a skill name.
|
||||
|
||||
Returns None if the skill is not in the registry or the index
|
||||
is unavailable.
|
||||
"""
|
||||
entry = self.get_skill_entry(name)
|
||||
if not entry:
|
||||
return None
|
||||
git_url = entry.get("git_url")
|
||||
if not git_url:
|
||||
return None
|
||||
subdirectory = entry.get("subdirectory") or None
|
||||
return str(git_url), subdirectory
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cache internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _load_cache(self) -> dict | None:
|
||||
"""Read cached index from disk. Returns None if absent or unreadable."""
|
||||
try:
|
||||
data = json.loads(self._index_path.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("registry: could not read cache: %s", exc)
|
||||
return None
|
||||
|
||||
def _save_cache(self, data: dict) -> None:
|
||||
"""Write index to disk atomically (.tmp then rename)."""
|
||||
try:
|
||||
self._index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = self._index_path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._index_path)
|
||||
# Update metadata
|
||||
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
|
||||
meta_tmp = self._metadata_path.with_suffix(".tmp")
|
||||
meta_tmp.write_text(json.dumps(meta, indent=2), encoding="utf-8")
|
||||
meta_tmp.replace(self._metadata_path)
|
||||
except Exception as exc:
|
||||
logger.debug("registry: could not write cache: %s", exc)
|
||||
|
||||
def _is_cache_fresh(self) -> bool:
|
||||
"""Return True if the cached index was fetched within the TTL."""
|
||||
try:
|
||||
meta = json.loads(self._metadata_path.read_text(encoding="utf-8"))
|
||||
last_fetched = datetime.fromisoformat(meta["last_fetched"])
|
||||
age = (datetime.now(tz=UTC) - last_fetched).total_seconds()
|
||||
return age < _CACHE_TTL_SECONDS
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _http_fetch(self, url: str, timeout: int = 10) -> bytes | None:
|
||||
"""Fetch URL contents. Returns None on any network error — never raises."""
|
||||
try:
|
||||
with urlopen(url, timeout=timeout) as resp: # noqa: S310
|
||||
return resp.read()
|
||||
except URLError as exc:
|
||||
logger.debug("registry: HTTP fetch failed for %s: %s", url, exc)
|
||||
return None
|
||||
except TimeoutError as exc:
|
||||
logger.debug("registry: HTTP fetch timed out for %s: %s", url, exc)
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("registry: unexpected error fetching %s: %s", url, exc)
|
||||
return None
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Strict SKILL.md validation for contributor tooling (hive skill validate).
|
||||
|
||||
Unlike the lenient parser used at runtime, this module applies hard-error rules
|
||||
that match the Agent Skills specification exactly. Intended for contributor
|
||||
tooling, CI gates, and hive skill doctor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import _MAX_NAME_LENGTH
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of a strict SKILL.md validation run."""
|
||||
|
||||
passed: bool
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def validate_strict(path: Path) -> ValidationResult:
|
||||
"""Run all strict checks against a SKILL.md file.
|
||||
|
||||
Applies hard-error rules that go beyond the lenient runtime parser:
|
||||
- name must be explicit (no directory-name fallback)
|
||||
- YAML must parse without fixup
|
||||
- name/directory mismatch is an error, not a warning
|
||||
- empty body is an error
|
||||
- scripts must be executable
|
||||
|
||||
Args:
|
||||
path: Path to the SKILL.md file to validate.
|
||||
|
||||
Returns:
|
||||
ValidationResult with passed=True if no errors, plus any warnings.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# 1. File exists and is readable
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
return ValidationResult(passed=False, errors=[f"File not found: {path}"])
|
||||
except PermissionError:
|
||||
return ValidationResult(passed=False, errors=[f"Permission denied reading: {path}"])
|
||||
except OSError as exc:
|
||||
return ValidationResult(passed=False, errors=[f"Cannot read file: {exc}"])
|
||||
|
||||
# 2. File not empty
|
||||
if not content.strip():
|
||||
return ValidationResult(passed=False, errors=["File is empty."])
|
||||
|
||||
# 3. YAML frontmatter present
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Missing YAML frontmatter — wrap frontmatter with --- delimiters."],
|
||||
)
|
||||
|
||||
raw_yaml = parts[1].strip()
|
||||
body = parts[2].strip()
|
||||
|
||||
if not raw_yaml:
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Frontmatter delimiters present but YAML block is empty."],
|
||||
)
|
||||
|
||||
# 4. YAML parses WITHOUT fixup (strict: unquoted colons are an error)
|
||||
import yaml
|
||||
|
||||
frontmatter: dict | None = None
|
||||
try:
|
||||
frontmatter = yaml.safe_load(raw_yaml)
|
||||
except yaml.YAMLError as exc:
|
||||
errors.append(
|
||||
f"YAML parse error: {exc}. "
|
||||
'Wrap values containing colons in quotes, e.g. description: "Use for: research".'
|
||||
)
|
||||
return ValidationResult(passed=False, errors=errors, warnings=warnings)
|
||||
|
||||
if not isinstance(frontmatter, dict):
|
||||
return ValidationResult(
|
||||
passed=False,
|
||||
errors=["Frontmatter is not a YAML key-value mapping."],
|
||||
)
|
||||
|
||||
# 5. description present and non-empty
|
||||
description = frontmatter.get("description")
|
||||
if not description or not str(description).strip():
|
||||
errors.append("Missing required field: 'description' must be present and non-empty.")
|
||||
|
||||
# 6. name present and non-empty (no directory-name fallback in strict mode)
|
||||
name = frontmatter.get("name")
|
||||
if not name or not str(name).strip():
|
||||
errors.append(
|
||||
"Missing required field: 'name' must be present. "
|
||||
"Add 'name: your-skill-name' to the frontmatter."
|
||||
)
|
||||
else:
|
||||
name = str(name).strip()
|
||||
parent_dir_name = path.parent.name
|
||||
|
||||
# 7. name length <= 64 chars
|
||||
if len(name) > _MAX_NAME_LENGTH:
|
||||
errors.append(
|
||||
f"Skill name '{name}' is {len(name)} characters — "
|
||||
f"maximum is {_MAX_NAME_LENGTH}. Shorten the name."
|
||||
)
|
||||
|
||||
# 8. name matches parent directory (dot-namespace prefix allowed: hive.X with dir X)
|
||||
if name != parent_dir_name and not name.endswith(f".{parent_dir_name}"):
|
||||
errors.append(
|
||||
f"Name '{name}' does not match directory '{parent_dir_name}'. "
|
||||
f"Rename the directory to '{name}' or set name to '{parent_dir_name}'."
|
||||
)
|
||||
|
||||
# 9. body non-empty
|
||||
if not body:
|
||||
errors.append(
|
||||
"Skill body (instructions) is empty. "
|
||||
"Add markdown instructions after the closing --- delimiter."
|
||||
)
|
||||
|
||||
# 10. license present — warning only
|
||||
if not frontmatter.get("license"):
|
||||
warnings.append("No 'license' field — consider adding a license (e.g. MIT, Apache-2.0).")
|
||||
|
||||
# 11. Scripts in scripts/ exist and are executable (POSIX only —
|
||||
# Windows does not use POSIX permission bits)
|
||||
base_dir = path.parent
|
||||
scripts_dir = base_dir / "scripts"
|
||||
if scripts_dir.is_dir() and os.name != "nt":
|
||||
for script_path in sorted(scripts_dir.iterdir()):
|
||||
if script_path.is_file():
|
||||
if not (script_path.stat().st_mode & (stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)):
|
||||
errors.append(
|
||||
f"Script not executable: {script_path.name}. Run: chmod +x {script_path}"
|
||||
)
|
||||
|
||||
# 12. allowed-tools entries are non-empty strings — warning if malformed
|
||||
allowed_tools = frontmatter.get("allowed-tools")
|
||||
if allowed_tools is not None:
|
||||
if not isinstance(allowed_tools, list):
|
||||
warnings.append("'allowed-tools' should be a list of strings.")
|
||||
else:
|
||||
for tool in allowed_tools:
|
||||
if not isinstance(tool, str) or not tool.strip():
|
||||
warnings.append(f"'allowed-tools' entry {tool!r} is not a non-empty string.")
|
||||
|
||||
# 13. compatibility is a list of strings — error if malformed
|
||||
compatibility = frontmatter.get("compatibility")
|
||||
if compatibility is not None:
|
||||
if not isinstance(compatibility, list):
|
||||
errors.append("'compatibility' must be a list of strings.")
|
||||
else:
|
||||
for item in compatibility:
|
||||
if not isinstance(item, str):
|
||||
errors.append(f"'compatibility' entry {item!r} is not a string.")
|
||||
|
||||
# 14. metadata is a dict — error if malformed
|
||||
metadata = frontmatter.get("metadata")
|
||||
if metadata is not None and not isinstance(metadata, dict):
|
||||
errors.append("'metadata' must be a YAML mapping (dict), not a scalar or list.")
|
||||
|
||||
return ValidationResult(
|
||||
passed=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Storage backends for runtime data."""
|
||||
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
from framework.storage.conversation_store import FileConversationStore
|
||||
|
||||
__all__ = ["FileStorage", "FileConversationStore"]
|
||||
__all__ = ["ConcurrentStorage", "FileConversationStore"]
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
File-based storage backend for runtime data.
|
||||
|
||||
DEPRECATED: This storage backend is deprecated for new sessions.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This module is kept for backward compatibility with old run data only.
|
||||
|
||||
Uses Pydantic's built-in serialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
|
||||
class FileStorage:
|
||||
"""
|
||||
DEPRECATED: File-based storage for old runs only.
|
||||
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This class is kept for backward compatibility with old run data.
|
||||
|
||||
Old directory structure (deprecated):
|
||||
{base_path}/
|
||||
runs/ # DEPRECATED - no longer written
|
||||
{run_id}.json
|
||||
summaries/ # DEPRECATED - no longer written
|
||||
{run_id}.json
|
||||
indexes/ # DEPRECATED - no longer written or read
|
||||
by_goal/
|
||||
{goal_id}.json
|
||||
by_status/
|
||||
{status}.json
|
||||
by_node/
|
||||
{node_id}.json
|
||||
"""
|
||||
|
||||
def __init__(self, base_path: str | Path):
|
||||
self.base_path = Path(base_path)
|
||||
self._ensure_dirs()
|
||||
|
||||
def _ensure_dirs(self) -> None:
|
||||
"""Create directory structure if it doesn't exist.
|
||||
|
||||
DEPRECATED: All directories (runs/, summaries/, indexes/) are deprecated.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
This method is now a no-op. Tests should not rely on this.
|
||||
"""
|
||||
# No-op: do not create deprecated directories
|
||||
pass
|
||||
|
||||
def _validate_key(self, key: str) -> None:
|
||||
"""
|
||||
Validate key to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
key: The key to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If key contains path traversal or dangerous patterns
|
||||
"""
|
||||
if not key or key.strip() == "":
|
||||
raise ValueError("Key cannot be empty")
|
||||
|
||||
# Block path separators
|
||||
if "/" in key or "\\" in key:
|
||||
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
|
||||
|
||||
# Block parent directory references
|
||||
if ".." in key or key.startswith("."):
|
||||
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
|
||||
|
||||
# Block absolute paths
|
||||
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
|
||||
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
|
||||
|
||||
# Block null bytes (Unix path injection)
|
||||
if "\x00" in key:
|
||||
raise ValueError("Invalid key format: null bytes not allowed")
|
||||
|
||||
# Block other dangerous special characters
|
||||
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
|
||||
if any(char in key for char in dangerous_chars):
|
||||
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
|
||||
|
||||
# === RUN OPERATIONS ===
|
||||
|
||||
def save_run(self, run: Run) -> None:
|
||||
"""Save a run to storage.
|
||||
|
||||
DEPRECATED: This method is now a no-op.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
Tests should not rely on FileStorage - use unified session storage instead.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.save_run() is deprecated. "
|
||||
"New sessions use unified storage at sessions/{session_id}/state.json. "
|
||||
"This write has been skipped.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# No-op: do not write to deprecated locations
|
||||
|
||||
def load_run(self, run_id: str) -> Run | None:
|
||||
"""Load a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
if not run_path.exists():
|
||||
return None
|
||||
with open(run_path, encoding="utf-8") as f:
|
||||
return Run.model_validate_json(f.read())
|
||||
|
||||
def load_summary(self, run_id: str) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
if not summary_path.exists():
|
||||
# Fall back to computing from full run
|
||||
run = self.load_run(run_id)
|
||||
if run:
|
||||
return RunSummary.from_run(run)
|
||||
return None
|
||||
|
||||
with open(summary_path, encoding="utf-8") as f:
|
||||
return RunSummary.model_validate_json(f.read())
|
||||
|
||||
def delete_run(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
|
||||
if not run_path.exists():
|
||||
return False
|
||||
|
||||
# Load run to get index keys
|
||||
run = self.load_run(run_id)
|
||||
if run:
|
||||
self._remove_from_index("by_goal", run.goal_id, run_id)
|
||||
self._remove_from_index("by_status", run.status.value, run_id)
|
||||
for node_id in run.metrics.nodes_executed:
|
||||
self._remove_from_index("by_node", node_id, run_id)
|
||||
|
||||
run_path.unlink()
|
||||
if summary_path.exists():
|
||||
summary_path.unlink()
|
||||
|
||||
return True
|
||||
|
||||
# === QUERY OPERATIONS ===
|
||||
|
||||
def get_runs_by_goal(self, goal_id: str) -> list[str]:
|
||||
"""Get all run IDs for a goal.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_goal() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._get_index("by_goal", goal_id)
|
||||
|
||||
def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
|
||||
"""Get all run IDs with a status.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_status() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if isinstance(status, RunStatus):
|
||||
status = status.value
|
||||
return self._get_index("by_status", status)
|
||||
|
||||
def get_runs_by_node(self, node_id: str) -> list[str]:
|
||||
"""Get all run IDs that executed a node.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns old run IDs from deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.get_runs_by_node() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._get_index("by_node", node_id)
|
||||
|
||||
def list_all_runs(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
runs_dir = self.base_path / "runs"
|
||||
return [f.stem for f in runs_dir.glob("*.json")]
|
||||
|
||||
def list_all_goals(self) -> list[str]:
|
||||
"""List all goal IDs that have runs.
|
||||
|
||||
DEPRECATED: Indexes are deprecated. For new sessions, scan sessions/*/state.json instead.
|
||||
This method only returns goals from old run IDs in deprecated indexes.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"FileStorage.list_all_goals() is deprecated. "
|
||||
"For new sessions, scan sessions/*/state.json instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
goals_dir = self.base_path / "indexes" / "by_goal"
|
||||
if not goals_dir.exists():
|
||||
return []
|
||||
return [f.stem for f in goals_dir.glob("*.json")]
|
||||
|
||||
# === INDEX OPERATIONS ===
|
||||
|
||||
def _get_index(self, index_type: str, key: str) -> list[str]:
|
||||
"""Get values from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
if not index_path.exists():
|
||||
return []
|
||||
with open(index_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _add_to_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Add a value to an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value not in values:
|
||||
values.append(value)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
def _remove_from_index(self, index_type: str, key: str, value: str) -> None:
|
||||
"""Remove a value from an index."""
|
||||
self._validate_key(key) # Prevent path traversal
|
||||
index_path = self.base_path / "indexes" / index_type / f"{key}.json"
|
||||
values = self._get_index(index_type, key) # Already validated in _get_index
|
||||
if value in values:
|
||||
values.remove(value)
|
||||
with atomic_write(index_path) as f:
|
||||
json.dump(values, f, indent=2)
|
||||
|
||||
# === UTILITY ===
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get storage statistics."""
|
||||
return {
|
||||
"total_runs": len(self.list_all_runs()),
|
||||
"total_goals": len(self.list_all_goals()),
|
||||
"storage_path": str(self.base_path),
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Concurrent Storage - Thread-safe storage backend with file locking.
|
||||
|
||||
Wraps FileStorage with:
|
||||
Provides:
|
||||
- Async file locking for atomic writes
|
||||
- Write batching for performance
|
||||
- Read caching for concurrent access
|
||||
@@ -16,8 +16,8 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.schemas.run import Run, RunSummary
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,7 +41,6 @@ class ConcurrentStorage:
|
||||
- Async file locking to prevent concurrent write corruption
|
||||
- Write batching to reduce I/O overhead
|
||||
- Read caching for frequently accessed data
|
||||
- Compatible API with FileStorage
|
||||
|
||||
Example:
|
||||
storage = ConcurrentStorage("/path/to/storage")
|
||||
@@ -75,7 +74,6 @@ class ConcurrentStorage:
|
||||
max_locks: Maximum number of active file locks to track strongly
|
||||
"""
|
||||
self.base_path = Path(base_path)
|
||||
self._base_storage = FileStorage(base_path)
|
||||
|
||||
# Caching
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
@@ -157,6 +155,93 @@ class ConcurrentStorage:
|
||||
|
||||
return lock
|
||||
|
||||
# === KEY VALIDATION ===
|
||||
|
||||
@staticmethod
|
||||
def _validate_key(key: str) -> None:
|
||||
"""Validate key to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
key: The key to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If key contains path traversal or dangerous patterns
|
||||
"""
|
||||
if not key or key.strip() == "":
|
||||
raise ValueError("Key cannot be empty")
|
||||
|
||||
if "/" in key or "\\" in key:
|
||||
raise ValueError(f"Invalid key format: path separators not allowed in '{key}'")
|
||||
|
||||
if ".." in key or key.startswith("."):
|
||||
raise ValueError(f"Invalid key format: path traversal detected in '{key}'")
|
||||
|
||||
if key.startswith("/") or (len(key) > 1 and key[1] == ":"):
|
||||
raise ValueError(f"Invalid key format: absolute paths not allowed in '{key}'")
|
||||
|
||||
if "\x00" in key:
|
||||
raise ValueError("Invalid key format: null bytes not allowed")
|
||||
|
||||
dangerous_chars = {"<", ">", "|", "&", "$", "`", "'", '"'}
|
||||
if any(char in key for char in dangerous_chars):
|
||||
raise ValueError(f"Invalid key format: contains dangerous characters in '{key}'")
|
||||
|
||||
# === FILE OPERATIONS (formerly in FileStorage) ===
|
||||
|
||||
def _save_run_sync(self, run: Run) -> None:
|
||||
"""Persist a run to disk as ``runs/{run_id}.json``.
|
||||
|
||||
Uses an atomic write (temp-file + rename) so a mid-write crash
|
||||
never leaves a partially written file on disk.
|
||||
"""
|
||||
self._validate_key(run.id)
|
||||
runs_dir = self.base_path / "runs"
|
||||
runs_dir.mkdir(parents=True, exist_ok=True)
|
||||
run_path = runs_dir / f"{run.id}.json"
|
||||
with atomic_write(run_path) as f:
|
||||
f.write(run.model_dump_json(indent=2))
|
||||
|
||||
def _load_run_sync(self, run_id: str) -> Run | None:
|
||||
"""Load a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
if not run_path.exists():
|
||||
return None
|
||||
with open(run_path, encoding="utf-8") as f:
|
||||
return Run.model_validate_json(f.read())
|
||||
|
||||
def _load_summary_sync(self, run_id: str) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
self._validate_key(run_id)
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
if not summary_path.exists():
|
||||
run = self._load_run_sync(run_id)
|
||||
if run:
|
||||
return RunSummary.from_run(run)
|
||||
return None
|
||||
with open(summary_path, encoding="utf-8") as f:
|
||||
return RunSummary.model_validate_json(f.read())
|
||||
|
||||
def _delete_run_sync(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
run_path = self.base_path / "runs" / f"{run_id}.json"
|
||||
summary_path = self.base_path / "summaries" / f"{run_id}.json"
|
||||
|
||||
if not run_path.exists():
|
||||
return False
|
||||
|
||||
run_path.unlink()
|
||||
if summary_path.exists():
|
||||
summary_path.unlink()
|
||||
|
||||
return True
|
||||
|
||||
def _list_all_runs_sync(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
runs_dir = self.base_path / "runs"
|
||||
if not runs_dir.exists():
|
||||
return []
|
||||
return [f.stem for f in runs_dir.glob("*.json")]
|
||||
|
||||
# === RUN OPERATIONS (Async, Thread-Safe) ===
|
||||
|
||||
async def save_run(self, run: Run, immediate: bool = False) -> None:
|
||||
@@ -180,40 +265,17 @@ class ConcurrentStorage:
|
||||
await self._write_queue.put(("run", run))
|
||||
|
||||
async def _save_run_locked(self, run: Run) -> None:
|
||||
"""Save a run with file locking, including index locks."""
|
||||
"""Save a run with file locking."""
|
||||
lock_key = f"run:{run.id}"
|
||||
|
||||
# Helper to get lock
|
||||
async def get_lock(k):
|
||||
return await self._get_lock(k)
|
||||
|
||||
# Acquire main lock
|
||||
run_lock = await get_lock(lock_key)
|
||||
run_lock = await self._get_lock(lock_key)
|
||||
|
||||
async with run_lock:
|
||||
# 2. Acquire index locks
|
||||
index_lock_keys = [
|
||||
f"index:by_goal:{run.goal_id}",
|
||||
f"index:by_status:{run.status.value}",
|
||||
]
|
||||
for node_id in run.metrics.nodes_executed:
|
||||
index_lock_keys.append(f"index:by_node:{node_id}")
|
||||
|
||||
# Collect index locks
|
||||
index_locks = [await get_lock(k) for k in index_lock_keys]
|
||||
|
||||
# Recursive acquisition
|
||||
async def with_locks(locks, callback):
|
||||
if not locks:
|
||||
return await callback()
|
||||
async with locks[0]:
|
||||
return await with_locks(locks[1:], callback)
|
||||
|
||||
async def perform_save():
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._base_storage.save_run, run)
|
||||
await loop.run_in_executor(None, self._save_run_sync, run)
|
||||
|
||||
await with_locks(index_locks, perform_save)
|
||||
await perform_save()
|
||||
|
||||
async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
|
||||
"""
|
||||
@@ -225,7 +287,11 @@ class ConcurrentStorage:
|
||||
|
||||
Returns:
|
||||
Run object or None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
if use_cache:
|
||||
cache_key = f"run:{run_id}"
|
||||
cached = self._cache.get(cache_key)
|
||||
@@ -240,7 +306,7 @@ class ConcurrentStorage:
|
||||
lock_key = f"run:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
run = await loop.run_in_executor(None, self._base_storage.load_run, run_id)
|
||||
run = await loop.run_in_executor(None, self._load_run_sync, run_id)
|
||||
|
||||
# Update cache
|
||||
if run:
|
||||
@@ -249,7 +315,12 @@ class ConcurrentStorage:
|
||||
return run
|
||||
|
||||
async def load_summary(self, run_id: str, use_cache: bool = True) -> RunSummary | None:
|
||||
"""Load just the summary (faster than full run)."""
|
||||
"""Load just the summary (faster than full run).
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
cache_key = f"summary:{run_id}"
|
||||
|
||||
# Check cache
|
||||
@@ -262,7 +333,7 @@ class ConcurrentStorage:
|
||||
lock_key = f"summary:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
|
||||
summary = await loop.run_in_executor(None, self._load_summary_sync, run_id)
|
||||
|
||||
# Update cache
|
||||
if summary:
|
||||
@@ -271,11 +342,16 @@ class ConcurrentStorage:
|
||||
return summary
|
||||
|
||||
async def delete_run(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
"""Delete a run from storage.
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
lock_key = f"run:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, self._base_storage.delete_run, run_id)
|
||||
result = await loop.run_in_executor(None, self._delete_run_sync, run_id)
|
||||
|
||||
# Clear cache
|
||||
self._cache.pop(f"run:{run_id}", None)
|
||||
@@ -283,37 +359,10 @@ class ConcurrentStorage:
|
||||
|
||||
return result
|
||||
|
||||
# === QUERY OPERATIONS (Async, with Locking) ===
|
||||
|
||||
async def get_runs_by_goal(self, goal_id: str) -> list[str]:
|
||||
"""Get all run IDs for a goal."""
|
||||
async with await self._get_lock(f"index:by_goal:{goal_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_goal, goal_id)
|
||||
|
||||
async def get_runs_by_status(self, status: str | RunStatus) -> list[str]:
|
||||
"""Get all run IDs with a status."""
|
||||
if isinstance(status, RunStatus):
|
||||
status = status.value
|
||||
async with await self._get_lock(f"index:by_status:{status}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_status, status)
|
||||
|
||||
async def get_runs_by_node(self, node_id: str) -> list[str]:
|
||||
"""Get all run IDs that executed a node."""
|
||||
async with await self._get_lock(f"index:by_node:{node_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_node, node_id)
|
||||
|
||||
async def list_all_runs(self) -> list[str]:
|
||||
"""List all run IDs."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.list_all_runs)
|
||||
|
||||
async def list_all_goals(self) -> list[str]:
|
||||
"""List all goal IDs that have runs."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.list_all_goals)
|
||||
return await loop.run_in_executor(None, self._list_all_runs_sync)
|
||||
|
||||
# === BATCH OPERATIONS ===
|
||||
|
||||
@@ -411,10 +460,11 @@ class ConcurrentStorage:
|
||||
async def get_stats(self) -> dict:
|
||||
"""Get storage statistics."""
|
||||
loop = asyncio.get_event_loop()
|
||||
base_stats = await loop.run_in_executor(None, self._base_storage.get_stats)
|
||||
all_runs = await loop.run_in_executor(None, self._list_all_runs_sync)
|
||||
|
||||
return {
|
||||
**base_stats,
|
||||
"total_runs": len(all_runs),
|
||||
"storage_path": str(self.base_path),
|
||||
"cache": self.get_cache_stats(),
|
||||
"pending_writes": self._write_queue.qsize(),
|
||||
"running": self._running,
|
||||
@@ -423,10 +473,21 @@ class ConcurrentStorage:
|
||||
# === SYNC API (for backward compatibility) ===
|
||||
|
||||
def save_run_sync(self, run: Run) -> None:
|
||||
"""Synchronous save (uses base storage directly with lock)."""
|
||||
# Use threading lock for sync operations
|
||||
self._base_storage.save_run(run)
|
||||
"""Synchronous save — persists a run to disk immediately."""
|
||||
self._validate_key(run.id)
|
||||
# Invalidate summary cache since the run data is changing
|
||||
self._cache.pop(f"summary:{run.id}", None)
|
||||
|
||||
self._save_run_sync(run)
|
||||
|
||||
# Refresh run cache
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
|
||||
def load_run_sync(self, run_id: str) -> Run | None:
|
||||
"""Synchronous load (uses base storage directly)."""
|
||||
return self._base_storage.load_run(run_id)
|
||||
"""Synchronous load.
|
||||
|
||||
Raises:
|
||||
ValueError: If run_id contains path traversal characters.
|
||||
"""
|
||||
self._validate_key(run_id)
|
||||
return self._load_run_sync(run_id)
|
||||
|
||||
@@ -62,8 +62,14 @@ class SessionStore:
|
||||
|
||||
Returns:
|
||||
Path to session directory
|
||||
|
||||
Raises:
|
||||
ValueError: If session_id resolves outside the sessions directory
|
||||
"""
|
||||
return self.sessions_dir / session_id
|
||||
resolved = (self.sessions_dir / session_id).resolve()
|
||||
if not resolved.is_relative_to(self.sessions_dir.resolve()):
|
||||
raise ValueError(f"Invalid session ID: {session_id}")
|
||||
return resolved
|
||||
|
||||
def get_state_path(self, session_id: str) -> Path:
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,9 @@ class DebugTool:
|
||||
|
||||
Args:
|
||||
test_storage: Storage for test and result data
|
||||
runtime_storage: Optional FileStorage for Runtime data
|
||||
runtime_storage: Optional storage backend for Runtime data.
|
||||
Must expose a synchronous ``load_run_sync(run_id)`` method
|
||||
(e.g. ``ConcurrentStorage``).
|
||||
"""
|
||||
self.test_storage = test_storage
|
||||
self.runtime_storage = runtime_storage
|
||||
@@ -233,7 +235,9 @@ class DebugTool:
|
||||
return {}
|
||||
|
||||
try:
|
||||
run = self.runtime_storage.load_run(run_id)
|
||||
# Use the synchronous loader — _get_runtime_data is not async
|
||||
# and ConcurrentStorage.load_run() is a coroutine.
|
||||
run = self.runtime_storage.load_run_sync(run_id)
|
||||
if not run:
|
||||
return {"error": f"Run {run_id} not found"}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
File-based storage backend for test data.
|
||||
|
||||
Follows the same pattern as framework/storage/backend.py (FileStorage),
|
||||
Follows the same pattern as framework/storage/concurrent.py (ConcurrentStorage),
|
||||
storing tests as JSON files with indexes for efficient querying.
|
||||
"""
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
|
||||
"""
|
||||
from datetime import date, timedelta
|
||||
|
||||
from framework.agents.queen.queen_memory import read_episodic_memory
|
||||
from framework.agents.queen.queen_memory import format_memory_date, read_episodic_memory
|
||||
|
||||
days_back = max(1, min(days_back, 30))
|
||||
today = date.today()
|
||||
@@ -70,7 +70,7 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
|
||||
if not matched:
|
||||
continue
|
||||
content = "### ".join(matched)
|
||||
label = d.strftime("%B %-d, %Y")
|
||||
label = format_memory_date(d)
|
||||
if d == today:
|
||||
label = f"Today — {label}"
|
||||
entry = f"## {label}\n\n{content}"
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Routes, Route } from "react-router-dom";
|
||||
import Home from "./pages/home";
|
||||
import MyAgents from "./pages/my-agents";
|
||||
import Workspace from "./pages/workspace";
|
||||
import NotFound from "./pages/not-found";
|
||||
|
||||
function App() {
|
||||
return (
|
||||
@@ -9,6 +10,7 @@ function App() {
|
||||
<Route path="/" element={<Home />} />
|
||||
<Route path="/my-agents" element={<MyAgents />} />
|
||||
<Route path="/workspace" element={<Workspace />} />
|
||||
<Route path="*" element={<NotFound />} />
|
||||
</Routes>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import { Link } from "react-router-dom";
|
||||
|
||||
export default function NotFound() {
|
||||
return (
|
||||
<div className="min-h-screen bg-background flex flex-col items-center justify-center px-6 text-center">
|
||||
<h1 className="text-5xl font-semibold text-foreground">404</h1>
|
||||
<p className="mt-3 text-sm text-muted-foreground">Page not found</p>
|
||||
<p className="mt-1 text-sm text-muted-foreground/80">
|
||||
The page you’re looking for doesn’t exist.
|
||||
</p>
|
||||
<Link
|
||||
to="/"
|
||||
className="mt-6 inline-flex items-center rounded-lg border border-border/40 px-4 py-2 text-sm font-medium text-foreground hover:bg-muted/40 transition-colors"
|
||||
>
|
||||
Back to Home
|
||||
</Link>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -32,7 +32,7 @@ class _FakeRegistry:
|
||||
|
||||
def load_agent_selection(self, agent_path: Path):
|
||||
self.loaded_paths.append(agent_path)
|
||||
return list(self._returned_configs)
|
||||
return list(self._returned_configs), None
|
||||
|
||||
|
||||
def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
|
||||
@@ -61,7 +61,7 @@ def test_agent_runner_loads_registry_selected_servers(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
@@ -95,7 +95,7 @@ def test_agent_runner_skips_registry_when_no_servers_selected(tmp_path, monkeypa
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
@@ -135,7 +135,7 @@ def test_agent_runner_logs_actual_registry_load_results(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(AgentRunner, "_resolve_default_model", staticmethod(lambda: "test-model"))
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.load_registry_servers",
|
||||
lambda self, server_configs: [
|
||||
lambda self, server_configs, **kwargs: [
|
||||
{"server": "jira", "status": "loaded", "tools_loaded": 2, "skipped_reason": None},
|
||||
{
|
||||
"server": "slack",
|
||||
@@ -223,7 +223,7 @@ def test_integration_real_registry_to_agent_runner(tmp_path, monkeypatch):
|
||||
registered: list[dict] = []
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.tool_registry.ToolRegistry.register_mcp_server",
|
||||
lambda self, server_config, use_connection_manager=True: (
|
||||
lambda self, server_config, use_connection_manager=True, **kwargs: (
|
||||
registered.append(server_config) or 1
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from framework.skills.defaults import (
|
||||
SHARED_MEMORY_KEYS,
|
||||
SKILL_REGISTRY,
|
||||
DefaultSkillManager,
|
||||
is_batch_scenario,
|
||||
)
|
||||
from framework.skills.parser import parse_skill_md
|
||||
|
||||
@@ -186,3 +187,128 @@ class TestSkillsConfig:
|
||||
assert config.skills == []
|
||||
assert config.default_skills == {}
|
||||
assert config.all_defaults_disabled is False
|
||||
|
||||
|
||||
class TestConfigOverrideSubstitution:
|
||||
"""Config overrides replace {{placeholder}} values in injected protocol text."""
|
||||
|
||||
def test_quality_monitor_default_interval(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "Every 5 iterations" in prompt
|
||||
|
||||
def test_quality_monitor_override_interval(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.quality-monitor": {"assessment_interval": 10}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "Every 10 iterations" in prompt
|
||||
assert "Every 5 iterations" not in prompt
|
||||
|
||||
def test_error_recovery_default_retries(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "3+ times" in prompt
|
||||
|
||||
def test_error_recovery_override_retries(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.error-recovery": {"max_retries_per_tool": 5}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "5+ times" in prompt
|
||||
assert "3+ times" not in prompt
|
||||
|
||||
def test_context_preservation_default_threshold(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "45%" in prompt
|
||||
|
||||
def test_context_preservation_override_threshold(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.context-preservation": {"warn_at_usage_ratio": 0.4}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "40%" in prompt
|
||||
assert "45%" not in prompt
|
||||
|
||||
def test_no_unreplaced_placeholders_with_defaults(self):
|
||||
"""All {{...}} placeholders should be replaced when using defaults."""
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
prompt = manager.build_protocols_prompt()
|
||||
assert "{{" not in prompt
|
||||
|
||||
|
||||
class TestBatchAutoDetection:
|
||||
"""DS-12: is_batch_scenario() and batch_init_nudge property."""
|
||||
|
||||
def test_detects_list_of(self):
|
||||
assert is_batch_scenario("process a list of 100 leads") is True
|
||||
|
||||
def test_detects_collection_of(self):
|
||||
assert is_batch_scenario("a collection of invoices") is True
|
||||
|
||||
def test_detects_items(self):
|
||||
assert is_batch_scenario("go through all items in the spreadsheet") is True
|
||||
|
||||
def test_detects_for_each(self):
|
||||
assert is_batch_scenario("for each record, send an email") is True
|
||||
|
||||
def test_no_match_single_task(self):
|
||||
assert is_batch_scenario("write a summary of the quarterly report") is False
|
||||
|
||||
def test_batch_nudge_active_by_default(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
assert manager.batch_init_nudge is not None
|
||||
assert "_batch_ledger" in manager.batch_init_nudge
|
||||
|
||||
def test_batch_nudge_none_when_skill_disabled(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.batch-ledger": {"enabled": False}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
assert manager.batch_init_nudge is None
|
||||
|
||||
def test_batch_nudge_none_when_auto_detect_disabled(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.batch-ledger": {"auto_detect_batch": False}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
assert manager.batch_init_nudge is None
|
||||
|
||||
|
||||
class TestContextWarnRatio:
|
||||
"""DS-13: context_warn_ratio property."""
|
||||
|
||||
def test_default_ratio(self):
|
||||
manager = DefaultSkillManager()
|
||||
manager.load()
|
||||
assert manager.context_warn_ratio == pytest.approx(0.45)
|
||||
|
||||
def test_override_ratio(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.context-preservation": {"warn_at_usage_ratio": 0.3}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
assert manager.context_warn_ratio == pytest.approx(0.3)
|
||||
|
||||
def test_ratio_none_when_skill_disabled(self):
|
||||
config = SkillsConfig.from_agent_vars(
|
||||
default_skills={"hive.context-preservation": {"enabled": False}}
|
||||
)
|
||||
manager = DefaultSkillManager(config)
|
||||
manager.load()
|
||||
assert manager.context_warn_ratio is None
|
||||
|
||||
@@ -18,11 +18,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.config import get_llm_extra_kwargs
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import (
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
||||
LiteLLMProvider,
|
||||
_compute_retry_delay,
|
||||
_ensure_ollama_chat_prefix,
|
||||
_is_ollama_model,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
|
||||
@@ -93,9 +96,9 @@ class TestLiteLLMProviderInit:
|
||||
def test_init_ollama_no_key_needed(self):
|
||||
"""Test that Ollama models don't require API key."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Should not raise.
|
||||
# Should not raise; ollama/ is normalised to ollama_chat/ for tool-call support.
|
||||
provider = LiteLLMProvider(model="ollama/llama3")
|
||||
assert provider.model == "ollama/llama3"
|
||||
assert provider.model == "ollama_chat/llama3"
|
||||
|
||||
|
||||
class TestLiteLLMProviderComplete:
|
||||
@@ -1084,3 +1087,103 @@ class TestIsLocalModel:
|
||||
from framework.runner.runner import AgentRunner
|
||||
|
||||
assert AgentRunner._is_local_model(model) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ollama helper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsOllamaModel:
|
||||
"""Tests for _is_ollama_model()."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"ollama/llama3",
|
||||
"ollama/mistral:7b",
|
||||
"ollama_chat/llama3",
|
||||
"ollama_chat/qwen2.5:72b",
|
||||
],
|
||||
)
|
||||
def test_ollama_models_return_true(self, model):
|
||||
assert _is_ollama_model(model) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o-mini",
|
||||
"anthropic/claude-3-haiku",
|
||||
"openai/gpt-4o",
|
||||
"gemini/gemini-1.5-flash",
|
||||
"llama3",
|
||||
"",
|
||||
],
|
||||
)
|
||||
def test_non_ollama_models_return_false(self, model):
|
||||
assert _is_ollama_model(model) is False
|
||||
|
||||
|
||||
class TestEnsureOllamaChatPrefix:
|
||||
"""Tests for _ensure_ollama_chat_prefix()."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_model", "expected"),
|
||||
[
|
||||
("ollama/llama3", "ollama_chat/llama3"),
|
||||
("ollama/mistral:7b", "ollama_chat/mistral:7b"),
|
||||
("ollama/qwen2.5:72b-instruct", "ollama_chat/qwen2.5:72b-instruct"),
|
||||
],
|
||||
)
|
||||
def test_rewrites_ollama_to_ollama_chat(self, input_model, expected):
|
||||
assert _ensure_ollama_chat_prefix(input_model) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"ollama_chat/llama3",
|
||||
"gpt-4o-mini",
|
||||
"anthropic/claude-3-haiku",
|
||||
"gemini/gemini-1.5-flash",
|
||||
"",
|
||||
],
|
||||
)
|
||||
def test_leaves_non_ollama_prefix_unchanged(self, model):
|
||||
assert _ensure_ollama_chat_prefix(model) == model
|
||||
|
||||
|
||||
class TestGetLlmExtraKwargsOllama:
|
||||
"""Tests for num_ctx injection via get_llm_extra_kwargs() for Ollama."""
|
||||
|
||||
def test_ollama_provider_returns_num_ctx(self):
|
||||
"""Ollama config should inject num_ctx with default 16384."""
|
||||
config = {
|
||||
"llm": {"provider": "ollama", "model": "ollama/llama3"},
|
||||
}
|
||||
with patch("framework.config.get_hive_config", return_value=config):
|
||||
result = get_llm_extra_kwargs()
|
||||
assert result == {"num_ctx": 16384}
|
||||
|
||||
def test_ollama_provider_respects_custom_num_ctx(self):
|
||||
"""User-specified num_ctx in config should take precedence."""
|
||||
config = {
|
||||
"llm": {"provider": "ollama", "model": "ollama/llama3", "num_ctx": 32768},
|
||||
}
|
||||
with patch("framework.config.get_hive_config", return_value=config):
|
||||
result = get_llm_extra_kwargs()
|
||||
assert result == {"num_ctx": 32768}
|
||||
|
||||
def test_non_ollama_provider_returns_empty(self):
|
||||
"""Non-Ollama provider without subscriptions should return empty dict."""
|
||||
config = {
|
||||
"llm": {"provider": "anthropic", "model": "claude-3-haiku"},
|
||||
}
|
||||
with patch("framework.config.get_hive_config", return_value=config):
|
||||
result = get_llm_extra_kwargs()
|
||||
assert result == {}
|
||||
|
||||
def test_empty_config_returns_empty(self):
|
||||
"""Missing config should return empty dict."""
|
||||
with patch("framework.config.get_hive_config", return_value={}):
|
||||
result = get_llm_extra_kwargs()
|
||||
assert result == {}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Tests for MCP structured error formatting."""
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.runner.mcp_errors import (
|
||||
MCPAuthError,
|
||||
MCPError,
|
||||
MCPErrorCode,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_error_code_stored():
|
||||
err = MCPError(
|
||||
code=MCPErrorCode.MCP_AUTH_MISSING,
|
||||
what="Could not connect to server 'jira'",
|
||||
why="JIRA_API_TOKEN is not set",
|
||||
fix="Run: hive mcp config jira --set JIRA_API_TOKEN=<token>",
|
||||
)
|
||||
assert err.code == MCPErrorCode.MCP_AUTH_MISSING
|
||||
|
||||
|
||||
def test_mcp_error_message_format():
|
||||
err = MCPError(
|
||||
code=MCPErrorCode.MCP_AUTH_MISSING,
|
||||
what="Could not connect to server 'jira'",
|
||||
why="JIRA_API_TOKEN is not set",
|
||||
fix="Run: hive mcp config jira --set JIRA_API_TOKEN=<token>",
|
||||
)
|
||||
expected = (
|
||||
"[MCP_AUTH_MISSING]\n"
|
||||
"What failed: Could not connect to server 'jira'\n"
|
||||
"Why: JIRA_API_TOKEN is not set\n"
|
||||
"Fix: Run: hive mcp config jira --set JIRA_API_TOKEN=<token>"
|
||||
)
|
||||
assert str(err) == expected
|
||||
|
||||
|
||||
def test_mcp_tool_not_found_error():
|
||||
err = MCPToolNotFoundError(server="github", tool_name="create_pr")
|
||||
assert err.code == MCPErrorCode.MCP_TOOL_NOT_FOUND
|
||||
assert "create_pr" in str(err)
|
||||
assert "github" in str(err)
|
||||
|
||||
|
||||
def test_mcp_auth_error():
|
||||
err = MCPAuthError(server="jira", env_var="JIRA_API_TOKEN")
|
||||
assert err.code == MCPErrorCode.MCP_AUTH_MISSING
|
||||
assert "JIRA_API_TOKEN" in str(err)
|
||||
|
||||
|
||||
def test_mcp_client_raises_structured_error_for_missing_tool():
|
||||
from framework.runner.mcp_client import MCPClient, MCPServerConfig
|
||||
|
||||
config = MCPServerConfig(name="test-server", transport="stdio")
|
||||
client = MCPClient(config)
|
||||
client._connected = True
|
||||
client._tools = {} # empty — no tools registered
|
||||
|
||||
with pytest.raises(MCPToolNotFoundError) as exc_info:
|
||||
client.call_tool("nonexistent_tool", {})
|
||||
|
||||
assert exc_info.value.code == MCPErrorCode.MCP_TOOL_NOT_FOUND
|
||||
assert "test-server" in str(exc_info.value)
|
||||
assert "nonexistent_tool" in str(exc_info.value)
|
||||
@@ -619,8 +619,9 @@ def test_load_agent_selection(tmp_path: Path):
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
(agent_dir / "mcp_registry.json").write_text(json.dumps({"include": ["jira", "slack"]}))
|
||||
dicts = registry.load_agent_selection(agent_dir)
|
||||
assert len(dicts) == 2 and all("transport" in d for d in dicts)
|
||||
dicts, max_tools = registry.load_agent_selection(agent_dir)
|
||||
assert len(dicts) == 2 and max_tools is None
|
||||
assert all("transport" in d for d in dicts)
|
||||
|
||||
|
||||
def test_load_agent_selection_no_file(tmp_path: Path):
|
||||
@@ -628,7 +629,7 @@ def test_load_agent_selection_no_file(tmp_path: Path):
|
||||
registry.initialize()
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
assert registry.load_agent_selection(agent_dir) == []
|
||||
assert registry.load_agent_selection(agent_dir) == ([], None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -648,9 +649,9 @@ def test_load_agent_selection_rejects_wrong_types(tmp_path: Path, field, bad_val
|
||||
agent_dir = tmp_path / "agent"
|
||||
agent_dir.mkdir()
|
||||
(agent_dir / "mcp_registry.json").write_text(json.dumps({field: bad_value}))
|
||||
configs = registry.load_agent_selection(agent_dir)
|
||||
configs, max_tools = registry.load_agent_selection(agent_dir)
|
||||
# All bad fields are dropped, so resolve_for_agent gets no criteria and returns []
|
||||
assert configs == []
|
||||
assert configs == [] and max_tools is None
|
||||
|
||||
|
||||
# ── run_health_check ────────────────────────────────────────────────
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.runner.mcp_client import MCPTool
|
||||
from framework.runner.tool_registry import ToolRegistry
|
||||
|
||||
|
||||
def _patch_connection_manager_for_fake_stdio(monkeypatch, tool_map: dict[str, list[str]]) -> None:
|
||||
"""Avoid spawning real stdio MCP processes; return in-memory clients per server name."""
|
||||
|
||||
class FakeMCPClient:
|
||||
def __init__(self, config: Any):
|
||||
self.config = config
|
||||
|
||||
def connect(self) -> None:
|
||||
return
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return
|
||||
|
||||
def list_tools(self) -> list[MCPTool]:
|
||||
names = tool_map.get(self.config.name, [])
|
||||
return [_make_tool(n, self.config.name) for n in names]
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
class FakeManager:
|
||||
def acquire(self, config: Any) -> FakeMCPClient:
|
||||
return FakeMCPClient(config)
|
||||
|
||||
def release(self, _server_name: str) -> None:
|
||||
return
|
||||
|
||||
monkeypatch.setattr(
|
||||
"framework.runner.mcp_connection_manager.MCPConnectionManager.get_instance",
|
||||
lambda: FakeManager(),
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(name: str, server_name: str) -> MCPTool:
|
||||
return MCPTool(
|
||||
name=name,
|
||||
description=f"{name} from {server_name}",
|
||||
input_schema={"type": "object", "properties": {}, "required": []},
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
|
||||
def test_registry_first_wins_collisions(monkeypatch):
|
||||
"""
|
||||
When multiple registry servers expose the same tool name, the first server
|
||||
in load order should win and later servers should not overwrite it.
|
||||
"""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
)
|
||||
|
||||
assert registry.has_tool("tool_common") is True
|
||||
assert registry.has_tool("tool_hive") is True
|
||||
assert registry.has_tool("tool_coder") is True
|
||||
|
||||
assert registry.get_server_tool_names("s1") == {"tool_common", "tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_precedence_over_existing_mcp_servers(monkeypatch):
|
||||
"""Registry-loaded tools should not overwrite already registered MCP tools."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"pre": ["tool_common", "tool_pre"],
|
||||
"s1": ["tool_common", "tool_hive"],
|
||||
"s2": ["tool_common", "tool_coder"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.register_mcp_server(
|
||||
{"name": "pre", "transport": "stdio", "command": "fake", "args": [], "cwd": None}
|
||||
)
|
||||
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=True,
|
||||
)
|
||||
|
||||
assert registry.get_server_tool_names("pre") == {"tool_common", "tool_pre"}
|
||||
assert registry.get_server_tool_names("s1") == {"tool_hive"}
|
||||
assert registry.get_server_tool_names("s2") == {"tool_coder"}
|
||||
|
||||
|
||||
def test_registry_max_tools_cap(monkeypatch):
|
||||
"""max_tools caps the total number of newly added tools from registry servers."""
|
||||
|
||||
tool_map: dict[str, list[str]] = {
|
||||
"s1": ["tool_a", "tool_b"],
|
||||
"s2": ["tool_c"],
|
||||
}
|
||||
_patch_connection_manager_for_fake_stdio(monkeypatch, tool_map)
|
||||
|
||||
resolved_servers = [
|
||||
{"name": "s1", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
{"name": "s2", "transport": "stdio", "command": "fake", "args": [], "cwd": None},
|
||||
]
|
||||
|
||||
registry = ToolRegistry()
|
||||
registry.load_registry_servers(
|
||||
resolved_servers,
|
||||
log_summary=False,
|
||||
preserve_existing_tools=True,
|
||||
max_tools=2,
|
||||
)
|
||||
|
||||
assert registry.has_tool("tool_a") is True
|
||||
assert registry.has_tool("tool_b") is True
|
||||
assert registry.has_tool("tool_c") is False
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Tests for path traversal vulnerability fix in FileStorage.
|
||||
Tests for path traversal vulnerability protection in ConcurrentStorage.
|
||||
|
||||
Verifies that the _validate_key() method properly blocks path traversal attempts.
|
||||
Verifies that the _validate_key() method properly blocks path traversal
|
||||
attempts and that the public storage API enforces these checks end-to-end.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
@@ -9,23 +10,22 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import ConcurrentStorage
|
||||
|
||||
|
||||
class TestPathTraversalProtection:
|
||||
"""Tests for path traversal vulnerability protection."""
|
||||
"""Tests for path traversal vulnerability protection in ConcurrentStorage."""
|
||||
|
||||
@pytest.fixture
|
||||
def storage(self):
|
||||
"""Create a temporary storage instance for testing."""
|
||||
"""Create a temporary ConcurrentStorage instance for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield FileStorage(tmpdir)
|
||||
yield ConcurrentStorage(tmpdir)
|
||||
|
||||
# === VALID KEYS (should pass validation) ===
|
||||
|
||||
def test_valid_alphanumeric_key(self, storage):
|
||||
"""Alphanumeric keys should be allowed."""
|
||||
# Should not raise
|
||||
storage._validate_key("goal_123")
|
||||
storage._validate_key("run_abc_def")
|
||||
storage._validate_key("status_completed")
|
||||
@@ -40,7 +40,6 @@ class TestPathTraversalProtection:
|
||||
|
||||
def test_blocks_parent_directory_traversal(self, storage):
|
||||
"""Block .. path traversal attempts."""
|
||||
# These all have path separators which are blocked first
|
||||
with pytest.raises(ValueError):
|
||||
storage._validate_key("../../../etc/passwd")
|
||||
|
||||
@@ -55,13 +54,12 @@ class TestPathTraversalProtection:
|
||||
with pytest.raises(ValueError, match="path traversal detected"):
|
||||
storage._validate_key(".env")
|
||||
|
||||
# This also has path separator which is caught first
|
||||
# Also has a path separator which is caught first
|
||||
with pytest.raises(ValueError):
|
||||
storage._validate_key(".ssh/id_rsa")
|
||||
|
||||
def test_blocks_absolute_paths_unix(self, storage):
|
||||
"""Block absolute paths (Unix)."""
|
||||
# These have path separators which are blocked first
|
||||
with pytest.raises(ValueError):
|
||||
storage._validate_key("/etc/passwd")
|
||||
|
||||
@@ -70,7 +68,6 @@ class TestPathTraversalProtection:
|
||||
|
||||
def test_blocks_absolute_paths_windows(self, storage):
|
||||
"""Block absolute paths (Windows)."""
|
||||
# These have path separators which are blocked first
|
||||
with pytest.raises(ValueError):
|
||||
storage._validate_key("C:\\Windows\\System32")
|
||||
|
||||
@@ -115,68 +112,76 @@ class TestPathTraversalProtection:
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
storage._validate_key(" ")
|
||||
|
||||
# === END-TO-END TESTS ===
|
||||
# === END-TO-END TESTS (public API enforces validation) ===
|
||||
|
||||
def test_get_runs_by_goal_blocks_traversal(self, storage):
|
||||
"""get_runs_by_goal() should block path traversal."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_run_blocks_traversal(self, storage):
|
||||
"""load_run() must reject path traversal in the run_id."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_goal("../../../.env")
|
||||
await storage.load_run("../../../.env")
|
||||
|
||||
def test_get_runs_by_node_blocks_traversal(self, storage):
|
||||
"""get_runs_by_node() should block path traversal."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_run_valid_id_returns_none(self, storage):
|
||||
"""A valid but nonexistent run_id returns None, not an error."""
|
||||
result = await storage.load_run("legitimate_run_id", use_cache=False)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_run_blocks_traversal(self, storage):
|
||||
"""delete_run() must reject path traversal in the run_id."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_node("/etc/passwd")
|
||||
await storage.delete_run("../etc/passwd")
|
||||
|
||||
def test_get_runs_by_status_blocks_traversal(self, storage):
|
||||
"""get_runs_by_status() should block path traversal."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_summary_blocks_traversal(self, storage):
|
||||
"""load_summary() must reject path traversal in the run_id."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_status("..\\..\\windows\\system32")
|
||||
await storage.load_summary("../../../.env")
|
||||
|
||||
def test_valid_queries_still_work(self, storage):
|
||||
"""Valid queries should work after fix."""
|
||||
# These should return empty list, not raise errors
|
||||
result = storage.get_runs_by_goal("legitimate_goal")
|
||||
assert result == []
|
||||
|
||||
result = storage.get_runs_by_node("legitimate_node")
|
||||
assert result == []
|
||||
|
||||
result = storage.get_runs_by_status("completed")
|
||||
assert result == []
|
||||
|
||||
# === REAL-WORLD ATTACK SCENARIOS ===
|
||||
|
||||
def test_blocks_env_file_escape(self, storage):
|
||||
"""Block attempts to access .env files."""
|
||||
def test_load_run_sync_blocks_traversal(self, storage):
|
||||
"""load_run_sync() must reject path traversal in the run_id."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_goal("../../../.env")
|
||||
storage.load_run_sync("../../../.env")
|
||||
|
||||
def test_blocks_config_file_escape(self, storage):
|
||||
"""Block attempts to access config files."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_goal("../../../../etc/aden/database.yaml")
|
||||
def test_save_run_sync_blocks_traversal(self, storage):
|
||||
"""save_run_sync() must reject path traversal in the run_id."""
|
||||
from framework.schemas.run import Run
|
||||
|
||||
def test_blocks_web_shell_creation(self, storage):
|
||||
"""Block attempts to create web shells."""
|
||||
run = Run(id="../../../.env", goal_id="test", goal_description="", input_data={})
|
||||
with pytest.raises(ValueError):
|
||||
storage._add_to_index("by_goal", "../../var/www/html/shell", "malicious_code")
|
||||
storage.save_run_sync(run)
|
||||
|
||||
def test_blocks_cron_injection(self, storage):
|
||||
"""Block attempts to create cron jobs."""
|
||||
with pytest.raises(ValueError):
|
||||
storage._add_to_index("by_node", "../../../etc/cron.d/backdoor", "reverse_shell")
|
||||
def test_load_run_sync_valid_id_returns_none(self, storage):
|
||||
"""load_run_sync with a legitimate nonexistent ID returns None."""
|
||||
result = storage.load_run_sync("legitimate_run_id")
|
||||
assert result is None
|
||||
|
||||
def test_blocks_sudoers_modification(self, storage):
|
||||
"""Block attempts to modify sudoers file."""
|
||||
# === REAL-WORLD ATTACK SCENARIOS (end-to-end) ===
|
||||
|
||||
def test_blocks_env_file_escape_via_load_sync(self, storage):
|
||||
"""Block attempts to read .env files via load_run_sync."""
|
||||
with pytest.raises(ValueError):
|
||||
storage._add_to_index("by_status", "../../../../etc/sudoers", "ALL=(ALL) NOPASSWD:ALL")
|
||||
storage.load_run_sync("../../../.env")
|
||||
|
||||
def test_blocks_config_file_escape_via_load_sync(self, storage):
|
||||
"""Block attempts to access config files via load_run_sync."""
|
||||
with pytest.raises(ValueError):
|
||||
storage.load_run_sync("../../../../etc/aden/database.yaml")
|
||||
|
||||
def test_blocks_arbitrary_write_via_save_sync(self, storage):
|
||||
"""Block attempts to write arbitrary files via save_run_sync."""
|
||||
from framework.schemas.run import Run
|
||||
|
||||
run = Run(id="../../var/www/html/shell", goal_id="test", goal_description="", input_data={})
|
||||
with pytest.raises(ValueError):
|
||||
storage.save_run_sync(run)
|
||||
|
||||
|
||||
class TestPathTraversalWithActualFiles:
|
||||
"""Test path traversal protection with actual file operations."""
|
||||
|
||||
def test_cannot_escape_storage_directory(self):
|
||||
"""Verify that even with path traversal, we can't escape storage dir."""
|
||||
"""Verify that path traversal is caught before any filesystem access."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
storage_dir = tmpdir_path / "storage"
|
||||
@@ -186,31 +191,61 @@ class TestPathTraversalWithActualFiles:
|
||||
secret_file = tmpdir_path / "secret.txt"
|
||||
secret_file.write_text("SENSITIVE_DATA", encoding="utf-8")
|
||||
|
||||
storage = FileStorage(storage_dir)
|
||||
storage = ConcurrentStorage(storage_dir)
|
||||
|
||||
# Attempt to read the secret file via path traversal
|
||||
# Attempt to read the secret file via path traversal — must raise
|
||||
with pytest.raises(ValueError):
|
||||
storage.get_runs_by_goal("../secret")
|
||||
storage.load_run_sync("../secret")
|
||||
|
||||
# Verify the secret file was not accessed (still contains original data)
|
||||
# Verify the secret file was not accessed
|
||||
assert secret_file.read_text(encoding="utf-8") == "SENSITIVE_DATA"
|
||||
|
||||
def test_cannot_write_outside_storage(self):
|
||||
"""Verify that we can't write files outside storage directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
storage_dir = tmpdir_path / "storage"
|
||||
storage_dir.mkdir()
|
||||
def test_save_and_load_roundtrip(self, tmp_path):
|
||||
"""Verify save_run_sync/load_run_sync roundtrip works correctly."""
|
||||
from framework.schemas.run import Run, RunStatus
|
||||
|
||||
storage = FileStorage(storage_dir)
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
run = Run(
|
||||
id="run_test_123",
|
||||
goal_id="goal_abc",
|
||||
goal_description="Integration test",
|
||||
input_data={},
|
||||
)
|
||||
run.complete(RunStatus.COMPLETED, "done")
|
||||
|
||||
# Attempt to write outside storage directory
|
||||
with pytest.raises(ValueError):
|
||||
storage._add_to_index("by_goal", "../../malicious", "payload")
|
||||
storage.save_run_sync(run)
|
||||
|
||||
# Verify no file was created outside storage
|
||||
malicious_file = tmpdir_path / "malicious.json"
|
||||
assert not malicious_file.exists()
|
||||
loaded = storage.load_run_sync("run_test_123")
|
||||
assert loaded is not None
|
||||
assert loaded.id == "run_test_123"
|
||||
assert loaded.status == RunStatus.COMPLETED
|
||||
|
||||
# Verify the file is at the expected path
|
||||
run_file = tmp_path / "runs" / "run_test_123.json"
|
||||
assert run_file.exists()
|
||||
|
||||
|
||||
class TestSessionStorePathTraversal:
|
||||
"""Path traversal protection in SessionStore.get_session_path()."""
|
||||
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path):
|
||||
from framework.storage.session_store import SessionStore
|
||||
|
||||
return SessionStore(tmp_path)
|
||||
|
||||
def test_valid_session_id(self, store):
|
||||
path = store.get_session_path("session_20260206_143022_abc12345")
|
||||
assert path.name == "session_20260206_143022_abc12345"
|
||||
|
||||
def test_blocks_parent_traversal(self, store):
|
||||
with pytest.raises(ValueError, match="Invalid session ID"):
|
||||
store.get_session_path("../../etc/passwd")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session_blocks_traversal(self, store):
|
||||
with pytest.raises(ValueError, match="Invalid session ID"):
|
||||
await store.delete_session("../../package")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
from datetime import date
|
||||
|
||||
from framework.agents.queen import queen_memory
|
||||
from framework.tools.queen_memory_tools import recall_diary
|
||||
|
||||
|
||||
def test_format_memory_date_uses_unpadded_day() -> None:
|
||||
assert queen_memory.format_memory_date(date(2026, 3, 7)) == "March 7, 2026"
|
||||
|
||||
|
||||
def test_format_for_injection_formats_recent_memory(monkeypatch) -> None:
|
||||
monkeypatch.setattr(queen_memory, "read_semantic_memory", lambda: "")
|
||||
monkeypatch.setattr(
|
||||
queen_memory,
|
||||
"_find_recent_episodic",
|
||||
lambda lookback=7: (date(2026, 3, 7), "Remembered context."),
|
||||
)
|
||||
|
||||
result = queen_memory.format_for_injection()
|
||||
|
||||
assert "## March 7, 2026" in result
|
||||
assert "Remembered context." in result
|
||||
|
||||
|
||||
def test_recall_diary_formats_today_without_platform_specific_strftime(monkeypatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
queen_memory,
|
||||
"read_episodic_memory",
|
||||
lambda d=None: "Today's note." if d == date.today() else "",
|
||||
)
|
||||
|
||||
result = recall_diary(days_back=1)
|
||||
|
||||
assert "## Today" in result
|
||||
assert "Today's note." in result
|
||||
@@ -37,20 +37,21 @@ class TestRuntimeBasics:
|
||||
runtime.end_run(success=True)
|
||||
assert runtime.current_run is None
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="FileStorage.save_run() is deprecated and now a no-op. "
|
||||
"New sessions use unified storage at sessions/{session_id}/state.json"
|
||||
)
|
||||
def test_run_saved_on_end(self, tmp_path: Path):
|
||||
"""Run is saved to storage when ended."""
|
||||
"""Run is persisted to disk when ended.
|
||||
|
||||
ConcurrentStorage.save_run_sync() writes to runs/{run_id}.json
|
||||
via an atomic temp-file+rename. This is the primary guardrail
|
||||
ensuring end_run() does not silently discard completed runs.
|
||||
"""
|
||||
runtime = Runtime(tmp_path)
|
||||
|
||||
run_id = runtime.start_run("test_goal", "Test")
|
||||
runtime.end_run(success=True)
|
||||
|
||||
# Check file exists
|
||||
# ConcurrentStorage writes to {base_path}/runs/{run_id}.json
|
||||
run_file = tmp_path / "runs" / f"{run_id}.json"
|
||||
assert run_file.exists()
|
||||
assert run_file.exists(), f"Expected persisted run at {run_file}"
|
||||
|
||||
|
||||
class TestDecisionRecording:
|
||||
@@ -346,7 +347,7 @@ class TestNarrativeGeneration:
|
||||
"""Test automatic narrative generation."""
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="FileStorage.save_run() and get_runs_by_goal() are deprecated. "
|
||||
reason="save_run() and get_runs_by_goal() are deprecated. "
|
||||
"New sessions use unified storage at sessions/{session_id}/state.json"
|
||||
)
|
||||
def test_default_narrative_success(self, tmp_path: Path):
|
||||
@@ -369,7 +370,7 @@ class TestNarrativeGeneration:
|
||||
assert "completed successfully" in run.narrative
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="FileStorage.save_run() and get_runs_by_goal() are deprecated. "
|
||||
reason="save_run() and get_runs_by_goal() are deprecated. "
|
||||
"New sessions use unified storage at sessions/{session_id}/state.json"
|
||||
)
|
||||
def test_default_narrative_failure(self, tmp_path: Path):
|
||||
|
||||
@@ -0,0 +1,579 @@
|
||||
"""Integration tests for hive skill CLI command handlers.
|
||||
|
||||
Uses argparse.Namespace objects directly (not argv parsing) for concise tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from framework.skills.cli import (
|
||||
cmd_skill_doctor,
|
||||
cmd_skill_info,
|
||||
cmd_skill_init,
|
||||
cmd_skill_install,
|
||||
cmd_skill_list,
|
||||
cmd_skill_remove,
|
||||
cmd_skill_search,
|
||||
cmd_skill_test,
|
||||
cmd_skill_validate,
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_skill(parent: Path, name: str) -> Path:
|
||||
"""Create a minimal valid skill in parent/name/SKILL.md."""
|
||||
d = parent / name
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: A test skill.\nlicense: MIT\n---\n\n## Body\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return d
|
||||
|
||||
|
||||
class TestCmdSkillInit:
|
||||
def test_creates_skill_md(self, tmp_path):
|
||||
args = Namespace(skill_name="test-skill", target_dir=str(tmp_path))
|
||||
result = cmd_skill_init(args)
|
||||
assert result == 0
|
||||
assert (tmp_path / "test-skill" / "SKILL.md").exists()
|
||||
|
||||
def test_skill_md_contains_name(self, tmp_path):
|
||||
args = Namespace(skill_name="my-skill", target_dir=str(tmp_path))
|
||||
cmd_skill_init(args)
|
||||
content = (tmp_path / "my-skill" / "SKILL.md").read_text()
|
||||
assert "name: my-skill" in content
|
||||
|
||||
def test_error_when_dir_exists(self, tmp_path, capsys):
|
||||
(tmp_path / "existing").mkdir()
|
||||
args = Namespace(skill_name="existing", target_dir=str(tmp_path))
|
||||
result = cmd_skill_init(args)
|
||||
assert result == 1
|
||||
assert "already exists" in capsys.readouterr().err
|
||||
|
||||
def test_error_when_no_name(self, tmp_path, monkeypatch, capsys):
|
||||
# Non-interactive (stdin not a tty in test env) → error
|
||||
monkeypatch.setattr("sys.stdin.isatty", lambda: False)
|
||||
args = Namespace(skill_name=None, target_dir=str(tmp_path))
|
||||
result = cmd_skill_init(args)
|
||||
assert result == 1
|
||||
|
||||
|
||||
class TestCmdSkillValidate:
|
||||
def test_exits_0_on_valid_skill(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(path=str(skill_dir / "SKILL.md"))
|
||||
result = cmd_skill_validate(args)
|
||||
assert result == 0
|
||||
|
||||
def test_accepts_directory_path(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(path=str(skill_dir))
|
||||
result = cmd_skill_validate(args)
|
||||
assert result == 0
|
||||
|
||||
def test_exits_1_on_invalid_skill(self, tmp_path, capsys):
|
||||
skill_dir = tmp_path / "bad-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text("no frontmatter here", encoding="utf-8")
|
||||
args = Namespace(path=str(skill_dir / "SKILL.md"))
|
||||
result = cmd_skill_validate(args)
|
||||
assert result == 1
|
||||
assert "[ERROR]" in capsys.readouterr().out
|
||||
|
||||
def test_shows_warnings_on_valid_skill_without_license(self, tmp_path, capsys):
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: my-skill\ndescription: No license.\n---\n\n## Body\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
args = Namespace(path=str(skill_dir / "SKILL.md"))
|
||||
result = cmd_skill_validate(args)
|
||||
assert result == 0
|
||||
assert "[WARN]" in capsys.readouterr().out
|
||||
|
||||
|
||||
class TestCmdSkillDoctor:
|
||||
def test_defaults_pass_against_real_framework_skills(self):
|
||||
"""All 6 framework default skills should be healthy (no mocking)."""
|
||||
args = Namespace(defaults=True, name=None, project_dir=None)
|
||||
result = cmd_skill_doctor(args)
|
||||
assert result == 0
|
||||
|
||||
def test_named_skill_not_found_exits_1(self, tmp_path, capsys):
|
||||
args = Namespace(name="nonexistent-skill", defaults=False, project_dir=str(tmp_path))
|
||||
result = cmd_skill_doctor(args)
|
||||
assert result == 1
|
||||
assert "not found" in capsys.readouterr().err
|
||||
|
||||
def test_healthy_skill_exits_0(self, tmp_path):
|
||||
_make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(name=None, defaults=False, project_dir=str(tmp_path))
|
||||
with patch("framework.skills.discovery.SkillDiscovery.discover") as mock_discover:
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
mock_discover.return_value = [
|
||||
ParsedSkill(
|
||||
name="my-skill",
|
||||
description="Test.",
|
||||
location=str(tmp_path / "my-skill" / "SKILL.md"),
|
||||
base_dir=str(tmp_path / "my-skill"),
|
||||
source_scope="user",
|
||||
body="## Body",
|
||||
)
|
||||
]
|
||||
result = cmd_skill_doctor(args)
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestCmdSkillInstall:
|
||||
def test_shows_security_notice_on_first_use(self, tmp_path, monkeypatch, capsys):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
installed_path = tmp_path / "skills" / "my-skill"
|
||||
installed_path.mkdir(parents=True)
|
||||
|
||||
args = Namespace(
|
||||
name_or_url=None,
|
||||
from_url="https://example.com/skill.git",
|
||||
pack=None,
|
||||
install_name="my-skill",
|
||||
version=None,
|
||||
)
|
||||
|
||||
with patch("framework.skills.installer.install_from_git", return_value=installed_path):
|
||||
with patch("shutil.which", return_value="/usr/bin/git"):
|
||||
result = cmd_skill_install(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Security Notice" in captured.out
|
||||
assert result == 0
|
||||
|
||||
def test_install_from_url_calls_install_from_git(self, tmp_path, monkeypatch):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
sentinel.parent.mkdir(parents=True, exist_ok=True)
|
||||
sentinel.touch()
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
installed_path = tmp_path / "skills" / "my-skill"
|
||||
installed_path.mkdir(parents=True)
|
||||
|
||||
args = Namespace(
|
||||
name_or_url=None,
|
||||
from_url="https://github.com/org/my-skill.git",
|
||||
pack=None,
|
||||
install_name=None,
|
||||
version=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"framework.skills.installer.install_from_git", return_value=installed_path
|
||||
) as mock_install:
|
||||
result = cmd_skill_install(args)
|
||||
|
||||
mock_install.assert_called_once()
|
||||
assert result == 0
|
||||
|
||||
def test_registry_not_found_exits_1(self, tmp_path, monkeypatch, capsys):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
sentinel.parent.mkdir(parents=True, exist_ok=True)
|
||||
sentinel.touch()
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
args = Namespace(
|
||||
name_or_url="nonexistent-skill",
|
||||
from_url=None,
|
||||
pack=None,
|
||||
install_name=None,
|
||||
version=None,
|
||||
)
|
||||
|
||||
with patch("framework.skills.registry.RegistryClient.get_skill_entry", return_value=None):
|
||||
result = cmd_skill_install(args)
|
||||
|
||||
assert result == 1
|
||||
assert "not found in registry" in capsys.readouterr().err
|
||||
|
||||
def test_no_args_exits_1(self, tmp_path, monkeypatch, capsys):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
sentinel.parent.mkdir(parents=True, exist_ok=True)
|
||||
sentinel.touch()
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
args = Namespace(
|
||||
name_or_url=None, from_url=None, pack=None, install_name=None, version=None
|
||||
)
|
||||
result = cmd_skill_install(args)
|
||||
assert result == 1
|
||||
|
||||
|
||||
class TestCmdSkillRemove:
|
||||
def test_removes_installed_skill(self, tmp_path, capsys):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skill_dir = skills_dir / "my-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
|
||||
with patch("framework.skills.installer.USER_SKILLS_DIR", skills_dir):
|
||||
with patch("framework.skills.installer.remove_skill", return_value=True):
|
||||
args = Namespace(name="my-skill")
|
||||
result = cmd_skill_remove(args)
|
||||
|
||||
assert result == 0
|
||||
assert "Removed" in capsys.readouterr().out
|
||||
|
||||
def test_exits_1_when_not_found(self, tmp_path, capsys):
|
||||
with patch("framework.skills.installer.remove_skill", return_value=False):
|
||||
args = Namespace(name="missing-skill")
|
||||
result = cmd_skill_remove(args)
|
||||
|
||||
assert result == 1
|
||||
assert "not found" in capsys.readouterr().err
|
||||
|
||||
|
||||
class TestCmdSkillSearch:
|
||||
def test_exits_1_when_registry_unavailable(self, capsys):
|
||||
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=None):
|
||||
args = Namespace(query="research")
|
||||
result = cmd_skill_search(args)
|
||||
|
||||
assert result == 1
|
||||
assert "registry unavailable" in capsys.readouterr().err.lower()
|
||||
|
||||
def test_prints_results_when_found(self, capsys):
|
||||
mock_index = {
|
||||
"skills": [
|
||||
{
|
||||
"name": "deep-research",
|
||||
"description": "Multi-step research.",
|
||||
"tags": ["research"],
|
||||
"trust_tier": "official",
|
||||
}
|
||||
]
|
||||
}
|
||||
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=mock_index):
|
||||
args = Namespace(query="research")
|
||||
result = cmd_skill_search(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert result == 0
|
||||
assert "deep-research" in out
|
||||
|
||||
def test_no_results_message(self, capsys):
|
||||
mock_index = {"skills": []}
|
||||
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=mock_index):
|
||||
args = Namespace(query="xyzzy-nothing")
|
||||
result = cmd_skill_search(args)
|
||||
|
||||
assert result == 0
|
||||
assert "No skills found" in capsys.readouterr().out
|
||||
|
||||
|
||||
class TestCmdSkillInfo:
|
||||
def test_shows_locally_installed_skill(self, tmp_path, capsys):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
from framework.skills.parser import ParsedSkill
|
||||
|
||||
mock_skill = ParsedSkill(
|
||||
name="my-skill",
|
||||
description="A test skill.",
|
||||
location=str(skill_dir / "SKILL.md"),
|
||||
base_dir=str(skill_dir),
|
||||
source_scope="user",
|
||||
body="## Body",
|
||||
license="MIT",
|
||||
)
|
||||
|
||||
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[mock_skill]):
|
||||
args = Namespace(name="my-skill", project_dir=str(tmp_path))
|
||||
result = cmd_skill_info(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert result == 0
|
||||
assert "my-skill" in out
|
||||
assert "A test skill." in out
|
||||
|
||||
def test_falls_back_to_registry_when_not_installed(self, capsys):
|
||||
registry_entry = {
|
||||
"name": "deep-research",
|
||||
"description": "Multi-step research.",
|
||||
"version": "1.0.0",
|
||||
"author": "anthropics",
|
||||
"trust_tier": "official",
|
||||
}
|
||||
|
||||
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
|
||||
with patch(
|
||||
"framework.skills.registry.RegistryClient.get_skill_entry",
|
||||
return_value=registry_entry,
|
||||
):
|
||||
args = Namespace(name="deep-research", project_dir=None)
|
||||
result = cmd_skill_info(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert result == 0
|
||||
assert "not installed" in out
|
||||
assert "deep-research" in out
|
||||
|
||||
def test_exits_1_when_not_found_anywhere(self, tmp_path, capsys):
|
||||
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
|
||||
with patch(
|
||||
"framework.skills.registry.RegistryClient.get_skill_entry", return_value=None
|
||||
):
|
||||
args = Namespace(name="ghost-skill", project_dir=str(tmp_path))
|
||||
result = cmd_skill_info(args)
|
||||
|
||||
assert result == 1
|
||||
|
||||
|
||||
class TestJsonFlag:
|
||||
def test_list_json_produces_valid_json(self, tmp_path, capsys):
|
||||
args = Namespace(project_dir=str(tmp_path), json=True)
|
||||
with patch("framework.skills.discovery.SkillDiscovery.discover", return_value=[]):
|
||||
result = cmd_skill_list(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 0
|
||||
assert "skills" in data
|
||||
assert isinstance(data["skills"], list)
|
||||
|
||||
def test_validate_json_valid_skill(self, tmp_path, capsys):
|
||||
from framework.skills.cli import cmd_skill_validate
|
||||
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(path=str(skill_dir / "SKILL.md"), json=True)
|
||||
result = cmd_skill_validate(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 0
|
||||
assert data["passed"] is True
|
||||
assert data["errors"] == []
|
||||
assert "warnings" in data
|
||||
|
||||
def test_doctor_defaults_json(self, capsys):
|
||||
args = Namespace(defaults=True, name=None, project_dir=None, json=True)
|
||||
result = cmd_skill_doctor(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 0
|
||||
assert "skills" in data
|
||||
assert len(data["skills"]) == 6 # 6 framework default skills
|
||||
assert data["total_errors"] == 0
|
||||
|
||||
def test_search_json_registry_unavailable_exits_1(self, capsys):
|
||||
with patch("framework.skills.registry.RegistryClient.fetch_index", return_value=None):
|
||||
args = Namespace(query="research", json=True)
|
||||
result = cmd_skill_search(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 1
|
||||
assert "error" in data
|
||||
|
||||
def test_remove_json_not_found_exits_1(self, capsys):
|
||||
with patch("framework.skills.installer.remove_skill", return_value=False):
|
||||
args = Namespace(name="ghost-skill", json=True)
|
||||
result = cmd_skill_remove(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 1
|
||||
assert "error" in data
|
||||
|
||||
|
||||
class TestCmdSkillTest:
|
||||
"""Tests for hive skill test (CLI-9)."""
|
||||
|
||||
def test_structural_only_valid_exits_0(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
|
||||
result = cmd_skill_test(args)
|
||||
assert result == 0
|
||||
|
||||
def test_structural_invalid_exits_1(self, tmp_path, capsys):
|
||||
skill_dir = tmp_path / "bad-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text("no frontmatter", encoding="utf-8")
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
|
||||
result = cmd_skill_test(args)
|
||||
assert result == 1
|
||||
assert "[ERROR]" in capsys.readouterr().out
|
||||
|
||||
def test_invocation_mode_calls_provider_with_skill_body(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
mock_response = LLMResponse(content="Hello!", model="claude-haiku-4-5-20251001")
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.complete.return_value = mock_response
|
||||
|
||||
args = Namespace(
|
||||
path=str(skill_dir), input_json='{"prompt": "say hello"}', model=None, json=False
|
||||
)
|
||||
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
|
||||
result = cmd_skill_test(args)
|
||||
|
||||
assert result == 0
|
||||
call_kwargs = mock_provider.complete.call_args
|
||||
assert call_kwargs is not None
|
||||
# system should be the skill body
|
||||
assert "system" in call_kwargs.kwargs or len(call_kwargs.args) >= 2
|
||||
|
||||
def test_invocation_extracts_prompt_from_json(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.complete.return_value = LLMResponse(
|
||||
content="response", model="claude-haiku-4-5-20251001"
|
||||
)
|
||||
|
||||
args = Namespace(
|
||||
path=str(skill_dir), input_json='{"prompt": "extracted prompt"}', model=None, json=False
|
||||
)
|
||||
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
|
||||
cmd_skill_test(args)
|
||||
|
||||
call = mock_provider.complete.call_args
|
||||
messages = call.kwargs.get("messages") or (call.args[0] if call.args else [])
|
||||
assert any("extracted prompt" in m.get("content", "") for m in messages)
|
||||
|
||||
def test_eval_suite_all_pass_exits_0(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
evals_dir = skill_dir / "evals"
|
||||
evals_dir.mkdir()
|
||||
(evals_dir / "evals.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"skill_name": "my-skill",
|
||||
"evals": [
|
||||
{"id": 1, "prompt": "Say hi.", "assertions": ["Response is a greeting"]}
|
||||
],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.complete.return_value = LLMResponse(
|
||||
content="Hello!", model="claude-haiku-4-5-20251001"
|
||||
)
|
||||
mock_judge = MagicMock()
|
||||
mock_judge.evaluate.return_value = {"passes": True, "explanation": "Looks good."}
|
||||
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
|
||||
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
|
||||
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
|
||||
result = cmd_skill_test(args)
|
||||
|
||||
assert result == 0
|
||||
|
||||
def test_eval_any_fail_exits_1(self, tmp_path):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
evals_dir = skill_dir / "evals"
|
||||
evals_dir.mkdir()
|
||||
(evals_dir / "evals.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"skill_name": "my-skill",
|
||||
"evals": [
|
||||
{"id": 1, "prompt": "Say hi.", "assertions": ["Impossible assertion"]}
|
||||
],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.complete.return_value = LLMResponse(
|
||||
content="Hello!", model="claude-haiku-4-5-20251001"
|
||||
)
|
||||
mock_judge = MagicMock()
|
||||
mock_judge.evaluate.return_value = {"passes": False, "explanation": "Did not satisfy."}
|
||||
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
|
||||
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
|
||||
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
|
||||
result = cmd_skill_test(args)
|
||||
|
||||
assert result == 1
|
||||
|
||||
def test_json_flag_structural_output(self, tmp_path, capsys):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=True)
|
||||
result = cmd_skill_test(args)
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 0
|
||||
assert "structural" in data
|
||||
assert data["structural"]["passed"] is True
|
||||
assert data["skill"] == "my-skill"
|
||||
|
||||
def test_json_flag_eval_results(self, tmp_path, capsys):
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
evals_dir = skill_dir / "evals"
|
||||
evals_dir.mkdir()
|
||||
(evals_dir / "evals.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"skill_name": "my-skill",
|
||||
"evals": [{"id": 1, "prompt": "Hi.", "assertions": ["Is a greeting"]}],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.complete.return_value = LLMResponse(
|
||||
content="Hello!", model="claude-haiku-4-5-20251001"
|
||||
)
|
||||
mock_judge = MagicMock()
|
||||
mock_judge.evaluate.return_value = {"passes": True, "explanation": "Yes."}
|
||||
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=True)
|
||||
with patch("framework.llm.anthropic.AnthropicProvider", return_value=mock_provider):
|
||||
with patch("framework.testing.llm_judge.LLMJudge", return_value=mock_judge):
|
||||
result = cmd_skill_test(args)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
data = json.loads(out)
|
||||
assert result == 0
|
||||
assert "evals" in data
|
||||
assert data["total_passed"] == 1
|
||||
assert data["total_failed"] == 0
|
||||
|
||||
def test_no_api_key_with_evals_degrades_gracefully(self, tmp_path, capsys):
|
||||
"""No API key + evals present → structural checks pass, skip LLM, exit 0."""
|
||||
skill_dir = _make_valid_skill(tmp_path, "my-skill")
|
||||
(skill_dir / "evals").mkdir()
|
||||
(skill_dir / "evals" / "evals.json").write_text(
|
||||
json.dumps({"skill_name": "my-skill", "evals": []}), encoding="utf-8"
|
||||
)
|
||||
|
||||
args = Namespace(path=str(skill_dir), input_json=None, model=None, json=False)
|
||||
with patch(
|
||||
"framework.llm.anthropic.AnthropicProvider",
|
||||
side_effect=ValueError("ANTHROPIC_API_KEY not set"),
|
||||
):
|
||||
result = cmd_skill_test(args)
|
||||
|
||||
assert result == 0
|
||||
assert "ANTHROPIC_API_KEY" in capsys.readouterr().err
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Tests for skill install, remove, and fork operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.installer import (
|
||||
fork_skill,
|
||||
install_from_git,
|
||||
maybe_show_install_notice,
|
||||
remove_skill,
|
||||
)
|
||||
from framework.skills.parser import ParsedSkill
|
||||
from framework.skills.skill_errors import SkillError
|
||||
|
||||
|
||||
def _make_skill_dir(parent: Path, name: str, body: str = "## Instructions\n\nDo things.") -> Path:
|
||||
"""Create a minimal skill directory with a valid SKILL.md."""
|
||||
skill_dir = parent / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: A test skill.\n---\n\n{body}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
|
||||
def _make_parsed_skill(base_dir: Path, name: str) -> ParsedSkill:
|
||||
"""Create a ParsedSkill pointing to base_dir."""
|
||||
return ParsedSkill(
|
||||
name=name,
|
||||
description="Test skill.",
|
||||
location=str(base_dir / "SKILL.md"),
|
||||
base_dir=str(base_dir),
|
||||
source_scope="user",
|
||||
body="## Instructions",
|
||||
)
|
||||
|
||||
|
||||
class TestInstallFromGit:
|
||||
def test_copies_skill_dir_to_target(self, tmp_path):
|
||||
"""Successful clone copies skill directory to target."""
|
||||
source_repo = tmp_path / "repo"
|
||||
_make_skill_dir(source_repo, ".") # SKILL.md at repo root
|
||||
|
||||
target = tmp_path / "skills"
|
||||
|
||||
def fake_clone(git_url, target_path, version=None):
|
||||
# Simulate git clone by copying source_repo into target_path
|
||||
import shutil
|
||||
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.copytree(source_repo, target_path)
|
||||
|
||||
with patch("framework.skills.installer._git_clone_shallow", side_effect=fake_clone):
|
||||
with patch("shutil.which", return_value="/usr/bin/git"):
|
||||
dest = install_from_git(
|
||||
git_url="https://example.com/skill.git",
|
||||
skill_name="my-skill",
|
||||
target_dir=target,
|
||||
)
|
||||
|
||||
assert (dest / "SKILL.md").exists()
|
||||
assert dest == target / "my-skill"
|
||||
|
||||
def test_raises_when_git_not_found(self, tmp_path):
|
||||
with patch("shutil.which", return_value=None):
|
||||
with pytest.raises(SkillError) as exc_info:
|
||||
install_from_git(
|
||||
git_url="https://example.com/skill.git",
|
||||
skill_name="my-skill",
|
||||
target_dir=tmp_path / "skills",
|
||||
)
|
||||
assert "git is not installed" in exc_info.value.why
|
||||
|
||||
def test_raises_when_skill_md_missing(self, tmp_path):
|
||||
"""Clone succeeds but no SKILL.md in the subdirectory → error."""
|
||||
empty_repo = tmp_path / "empty_repo"
|
||||
empty_repo.mkdir()
|
||||
|
||||
def fake_clone(git_url, target_path, version=None):
|
||||
import shutil
|
||||
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.copytree(empty_repo, target_path)
|
||||
|
||||
with patch("framework.skills.installer._git_clone_shallow", side_effect=fake_clone):
|
||||
with patch("shutil.which", return_value="/usr/bin/git"):
|
||||
with pytest.raises(SkillError) as exc_info:
|
||||
install_from_git(
|
||||
git_url="https://example.com/skill.git",
|
||||
skill_name="my-skill",
|
||||
subdirectory="deep-research",
|
||||
target_dir=tmp_path / "skills",
|
||||
)
|
||||
assert exc_info.value.code.value == "SKILL_NOT_FOUND"
|
||||
|
||||
def test_raises_when_target_already_exists(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
(skills_dir / "existing-skill").mkdir(parents=True)
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/git"):
|
||||
with pytest.raises(SkillError) as exc_info:
|
||||
install_from_git(
|
||||
git_url="https://example.com/skill.git",
|
||||
skill_name="existing-skill",
|
||||
target_dir=skills_dir,
|
||||
)
|
||||
assert "already exists" in exc_info.value.why
|
||||
|
||||
def test_cleans_temp_dir_on_clone_failure(self, tmp_path):
|
||||
"""Temporary directory is cleaned up even when clone fails."""
|
||||
created_tmp_dirs = []
|
||||
original_mkdtemp = __import__("tempfile").mkdtemp
|
||||
|
||||
def tracking_mkdtemp(**kwargs):
|
||||
d = original_mkdtemp(**kwargs)
|
||||
created_tmp_dirs.append(d)
|
||||
return d
|
||||
|
||||
def failing_clone(git_url, target_path, version=None):
|
||||
from framework.skills.skill_errors import SkillErrorCode as SEC
|
||||
|
||||
raise SkillError(
|
||||
code=SEC.SKILL_ACTIVATION_FAILED,
|
||||
what="clone failed",
|
||||
why="network error",
|
||||
fix="check network",
|
||||
)
|
||||
|
||||
with patch("tempfile.mkdtemp", side_effect=tracking_mkdtemp):
|
||||
with patch("framework.skills.installer._git_clone_shallow", side_effect=failing_clone):
|
||||
with patch("shutil.which", return_value="/usr/bin/git"):
|
||||
with pytest.raises(SkillError):
|
||||
install_from_git(
|
||||
git_url="https://example.com/skill.git",
|
||||
skill_name="my-skill",
|
||||
target_dir=tmp_path / "skills",
|
||||
)
|
||||
|
||||
# All created temp dirs should be cleaned up
|
||||
for d in created_tmp_dirs:
|
||||
assert not Path(d).exists(), f"Temp dir not cleaned: {d}"
|
||||
|
||||
|
||||
class TestRemoveSkill:
|
||||
def test_removes_existing_skill(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skill_dir = _make_skill_dir(skills_dir, "my-skill")
|
||||
assert skill_dir.exists()
|
||||
|
||||
result = remove_skill("my-skill", skills_dir=skills_dir)
|
||||
assert result is True
|
||||
assert not skill_dir.exists()
|
||||
|
||||
def test_returns_false_when_not_found(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
result = remove_skill("nonexistent", skills_dir=skills_dir)
|
||||
assert result is False
|
||||
|
||||
def test_raises_on_permission_error(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
_make_skill_dir(skills_dir, "locked-skill")
|
||||
|
||||
with patch("shutil.rmtree", side_effect=OSError("permission denied")):
|
||||
with pytest.raises(SkillError) as exc_info:
|
||||
remove_skill("locked-skill", skills_dir=skills_dir)
|
||||
assert "permission" in exc_info.value.why.lower()
|
||||
|
||||
|
||||
class TestForkSkill:
|
||||
def test_copies_skill_to_new_name(self, tmp_path):
|
||||
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
|
||||
source = _make_parsed_skill(source_dir, "my-skill")
|
||||
target_parent = tmp_path / "skills"
|
||||
|
||||
dest = fork_skill(source, "my-skill-fork", target_parent)
|
||||
|
||||
assert dest.exists()
|
||||
assert (dest / "SKILL.md").exists()
|
||||
|
||||
def test_rewrites_name_in_skill_md(self, tmp_path):
|
||||
source_dir = _make_skill_dir(tmp_path / "sources", "original")
|
||||
source = _make_parsed_skill(source_dir, "original")
|
||||
target_parent = tmp_path / "skills"
|
||||
|
||||
dest = fork_skill(source, "forked", target_parent)
|
||||
|
||||
import yaml
|
||||
|
||||
content = (dest / "SKILL.md").read_text(encoding="utf-8")
|
||||
parts = content.split("---", 2)
|
||||
fm = yaml.safe_load(parts[1])
|
||||
assert fm["name"] == "forked"
|
||||
|
||||
def test_raises_when_dest_already_exists(self, tmp_path):
|
||||
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
|
||||
source = _make_parsed_skill(source_dir, "my-skill")
|
||||
target_parent = tmp_path / "skills"
|
||||
(target_parent / "my-skill-fork").mkdir(parents=True)
|
||||
|
||||
with pytest.raises(SkillError) as exc_info:
|
||||
fork_skill(source, "my-skill-fork", target_parent)
|
||||
assert "already exists" in exc_info.value.why
|
||||
|
||||
def test_preserves_scripts_and_references(self, tmp_path):
|
||||
source_dir = _make_skill_dir(tmp_path / "sources", "my-skill")
|
||||
(source_dir / "scripts").mkdir()
|
||||
(source_dir / "scripts" / "run.sh").write_text("#!/bin/sh\necho hi")
|
||||
(source_dir / "references").mkdir()
|
||||
(source_dir / "references" / "guide.md").write_text("# Guide")
|
||||
source = _make_parsed_skill(source_dir, "my-skill")
|
||||
target_parent = tmp_path / "skills"
|
||||
|
||||
dest = fork_skill(source, "fork", target_parent)
|
||||
|
||||
assert (dest / "scripts" / "run.sh").exists()
|
||||
assert (dest / "references" / "guide.md").exists()
|
||||
|
||||
|
||||
class TestInstallNotice:
|
||||
def test_shown_on_first_call(self, tmp_path, monkeypatch, capsys):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
maybe_show_install_notice()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Security Notice" in captured.out
|
||||
assert sentinel.exists()
|
||||
|
||||
def test_not_shown_on_second_call(self, tmp_path, monkeypatch, capsys):
|
||||
sentinel = tmp_path / ".install_notice_shown"
|
||||
sentinel.parent.mkdir(parents=True, exist_ok=True)
|
||||
sentinel.touch()
|
||||
monkeypatch.setattr("framework.skills.installer.INSTALL_NOTICE_SENTINEL", sentinel)
|
||||
|
||||
maybe_show_install_notice()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Security Notice" not in captured.out
|
||||
@@ -0,0 +1,244 @@
|
||||
"""Tests for the RegistryClient skill registry client."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from urllib.error import URLError
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.registry import _CACHE_TTL_SECONDS, RegistryClient
|
||||
|
||||
_SAMPLE_INDEX = {
|
||||
"version": 1,
|
||||
"skills": [
|
||||
{
|
||||
"name": "deep-research",
|
||||
"description": "Multi-step web research with source verification.",
|
||||
"version": "1.0.0",
|
||||
"author": "anthropics",
|
||||
"license": "MIT",
|
||||
"tags": ["research", "web"],
|
||||
"git_url": "https://github.com/anthropics/skills",
|
||||
"subdirectory": "deep-research",
|
||||
"trust_tier": "official",
|
||||
},
|
||||
{
|
||||
"name": "code-review",
|
||||
"description": "Automated code review for style and correctness.",
|
||||
"version": "0.9.0",
|
||||
"author": "contributor",
|
||||
"tags": ["code", "review"],
|
||||
"git_url": "https://github.com/contributor/code-review",
|
||||
"subdirectory": None,
|
||||
"trust_tier": "community",
|
||||
},
|
||||
],
|
||||
"packs": [
|
||||
{
|
||||
"name": "research-starter",
|
||||
"description": "Research-focused skill bundle",
|
||||
"skills": ["deep-research"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_dir(tmp_path):
|
||||
return tmp_path / "registry_cache"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(cache_dir):
|
||||
return RegistryClient(registry_url="https://example.com/skill_index.json", cache_dir=cache_dir)
|
||||
|
||||
|
||||
class TestFetchIndex:
|
||||
def test_returns_none_on_network_error(self, client):
|
||||
with patch.object(client, "_http_fetch", return_value=None):
|
||||
result = client.fetch_index()
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_url_error(self, client):
|
||||
with patch("framework.skills.registry.urlopen", side_effect=URLError("connection refused")):
|
||||
result = client.fetch_index()
|
||||
assert result is None
|
||||
|
||||
def test_fetches_and_caches_index(self, client):
|
||||
raw = json.dumps(_SAMPLE_INDEX).encode()
|
||||
with patch.object(client, "_http_fetch", return_value=raw):
|
||||
result = client.fetch_index()
|
||||
assert result is not None
|
||||
assert len(result["skills"]) == 2
|
||||
# Cache should be written
|
||||
assert client._index_path.exists()
|
||||
|
||||
def test_uses_fresh_cache_without_network(self, client, cache_dir):
|
||||
# Write fresh cache
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
|
||||
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
|
||||
(cache_dir / "metadata.json").write_text(json.dumps(meta))
|
||||
|
||||
fetch_called = []
|
||||
|
||||
def _no_fetch(*a, **kw):
|
||||
fetch_called.append(1)
|
||||
|
||||
with patch.object(client, "_http_fetch", side_effect=_no_fetch):
|
||||
result = client.fetch_index()
|
||||
|
||||
assert not fetch_called, "Should not hit network when cache is fresh"
|
||||
assert result is not None
|
||||
|
||||
def test_refreshes_when_cache_is_stale(self, client, cache_dir):
|
||||
# Write stale cache (older than TTL)
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
|
||||
old_time = (datetime.now(tz=UTC) - timedelta(seconds=_CACHE_TTL_SECONDS + 60)).isoformat()
|
||||
meta = {"last_fetched": old_time}
|
||||
(cache_dir / "metadata.json").write_text(json.dumps(meta))
|
||||
|
||||
raw = json.dumps(_SAMPLE_INDEX).encode()
|
||||
with patch.object(client, "_http_fetch", return_value=raw) as mock_fetch:
|
||||
client.fetch_index()
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
def test_force_refresh_bypasses_fresh_cache(self, client, cache_dir):
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
|
||||
meta = {"last_fetched": datetime.now(tz=UTC).isoformat()}
|
||||
(cache_dir / "metadata.json").write_text(json.dumps(meta))
|
||||
|
||||
raw = json.dumps(_SAMPLE_INDEX).encode()
|
||||
with patch.object(client, "_http_fetch", return_value=raw) as mock_fetch:
|
||||
client.fetch_index(force_refresh=True)
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
def test_falls_back_to_stale_cache_on_network_error(self, client, cache_dir):
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
(cache_dir / "skill_index.json").write_text(json.dumps(_SAMPLE_INDEX))
|
||||
# No metadata → stale
|
||||
|
||||
with patch.object(client, "_http_fetch", return_value=None):
|
||||
result = client.fetch_index()
|
||||
|
||||
assert result is not None
|
||||
assert result["version"] == 1
|
||||
|
||||
|
||||
class TestSearch:
|
||||
def test_filters_by_name(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
results = client.search("deep")
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "deep-research"
|
||||
|
||||
def test_filters_by_description(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
results = client.search("source verification")
|
||||
assert any(r["name"] == "deep-research" for r in results)
|
||||
|
||||
def test_filters_by_tag(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
results = client.search("review")
|
||||
assert any(r["name"] == "code-review" for r in results)
|
||||
|
||||
def test_case_insensitive(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
results = client.search("DEEP")
|
||||
assert len(results) == 1
|
||||
|
||||
def test_returns_empty_when_unavailable(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=None):
|
||||
results = client.search("anything")
|
||||
assert results == []
|
||||
|
||||
def test_returns_empty_on_no_match(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
results = client.search("xyzzy-no-match")
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestGetSkillEntry:
|
||||
def test_finds_by_exact_name(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
entry = client.get_skill_entry("deep-research")
|
||||
assert entry is not None
|
||||
assert entry["name"] == "deep-research"
|
||||
|
||||
def test_returns_none_when_not_found(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
entry = client.get_skill_entry("nonexistent")
|
||||
assert entry is None
|
||||
|
||||
def test_returns_none_when_index_unavailable(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=None):
|
||||
entry = client.get_skill_entry("deep-research")
|
||||
assert entry is None
|
||||
|
||||
|
||||
class TestGetPack:
|
||||
def test_returns_skill_names(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
skills = client.get_pack("research-starter")
|
||||
assert skills == ["deep-research"]
|
||||
|
||||
def test_returns_none_when_pack_not_found(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
result = client.get_pack("nonexistent-pack")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_index_unavailable(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=None):
|
||||
result = client.get_pack("research-starter")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestResolveGitUrl:
|
||||
def test_returns_git_url_and_subdirectory(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
result = client.resolve_git_url("deep-research")
|
||||
assert result == ("https://github.com/anthropics/skills", "deep-research")
|
||||
|
||||
def test_returns_none_subdirectory_when_absent(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
result = client.resolve_git_url("code-review")
|
||||
git_url, subdir = result
|
||||
assert subdir is None
|
||||
|
||||
def test_returns_none_when_not_in_registry(self, client):
|
||||
with patch.object(client, "fetch_index", return_value=_SAMPLE_INDEX):
|
||||
result = client.resolve_git_url("not-there")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCacheAtomicWrite:
|
||||
def test_atomic_write_uses_tmp_then_replace(self, client, cache_dir, monkeypatch):
|
||||
written_paths = []
|
||||
original_write = Path.write_text
|
||||
|
||||
def tracking_write(self, data, encoding=None):
|
||||
written_paths.append(str(self))
|
||||
return original_write(self, data, encoding=encoding or "utf-8")
|
||||
|
||||
monkeypatch.setattr(Path, "write_text", tracking_write)
|
||||
client._save_cache(_SAMPLE_INDEX)
|
||||
|
||||
# .tmp file should have been written (then replaced — may not exist now)
|
||||
assert any(".tmp" in p for p in written_paths)
|
||||
# Final index file should exist
|
||||
assert client._index_path.exists()
|
||||
|
||||
def test_save_and_load_round_trip(self, client):
|
||||
client._save_cache(_SAMPLE_INDEX)
|
||||
loaded = client._load_cache()
|
||||
assert loaded == _SAMPLE_INDEX
|
||||
|
||||
def test_load_returns_none_when_absent(self, client):
|
||||
result = client._load_cache()
|
||||
assert result is None
|
||||
@@ -0,0 +1,401 @@
|
||||
"""Tests for strict SKILL.md validation (hive skill validate).
|
||||
|
||||
One test per strict check — happy path plus each individual failure mode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.validator import validate_strict
|
||||
|
||||
|
||||
def _write_skill(tmp_path: Path, content: str, dir_name: str = "my-skill") -> Path:
|
||||
"""Write a SKILL.md in a named subdirectory and return the path."""
|
||||
skill_dir = tmp_path / dir_name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
return skill_md
|
||||
|
||||
|
||||
_VALID_CONTENT = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: A test skill for validation.
|
||||
version: 0.1.0
|
||||
license: MIT
|
||||
compatibility:
|
||||
- claude-code
|
||||
- hive
|
||||
metadata:
|
||||
tags: []
|
||||
---
|
||||
|
||||
## Instructions
|
||||
|
||||
Do the thing properly.
|
||||
"""
|
||||
|
||||
|
||||
class TestHappyPath:
|
||||
def test_valid_skill_passes(self, tmp_path):
|
||||
path = _write_skill(tmp_path, _VALID_CONTENT)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
assert result.errors == []
|
||||
|
||||
def test_namespace_prefix_name_allowed(self, tmp_path):
|
||||
"""hive.my-skill with directory my-skill is valid."""
|
||||
content = """\
|
||||
---
|
||||
name: hive.my-skill
|
||||
description: A namespaced skill.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content, dir_name="my-skill")
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
|
||||
def test_warning_on_missing_license(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: No license field.
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
assert any("license" in w.lower() for w in result.warnings)
|
||||
|
||||
|
||||
class TestCheck1FileExists:
|
||||
def test_error_on_missing_file(self, tmp_path):
|
||||
path = tmp_path / "nonexistent" / "SKILL.md"
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("not found" in e.lower() for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck2FileNotEmpty:
|
||||
def test_error_on_empty_file(self, tmp_path):
|
||||
skill_dir = tmp_path / "my-skill"
|
||||
skill_dir.mkdir()
|
||||
path = skill_dir / "SKILL.md"
|
||||
path.write_text(" \n", encoding="utf-8")
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("empty" in e.lower() for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck3FrontmatterPresent:
|
||||
def test_error_on_missing_delimiters(self, tmp_path):
|
||||
path = _write_skill(tmp_path, "name: my-skill\ndescription: no delimiters\n")
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("frontmatter" in e.lower() or "---" in e for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck4YamlNoFixup:
|
||||
def test_error_on_yaml_requiring_fixup(self, tmp_path):
|
||||
"""Unquoted colon in value — lenient parser accepts, strict rejects."""
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Use for: research tasks
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("YAML" in e or "parse" in e.lower() for e in result.errors)
|
||||
|
||||
def test_quoted_colon_passes(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: "Use for: research tasks"
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
|
||||
|
||||
class TestCheck5Description:
|
||||
def test_error_on_missing_description(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("description" in e.lower() for e in result.errors)
|
||||
|
||||
def test_error_on_empty_description(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: ""
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
|
||||
|
||||
class TestCheck6NamePresent:
|
||||
def test_error_on_missing_name(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
description: A skill without a name.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("name" in e.lower() for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck7NameLength:
|
||||
def test_error_on_name_too_long(self, tmp_path):
|
||||
long_name = "a" * 65
|
||||
skill_dir = tmp_path / long_name
|
||||
skill_dir.mkdir(parents=True)
|
||||
content = f"---\nname: {long_name}\ndescription: Too long.\nlicense: MIT\n---\n\n## Body\n"
|
||||
path = skill_dir / "SKILL.md"
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("64" in e or "characters" in e.lower() for e in result.errors)
|
||||
|
||||
def test_exactly_64_chars_passes(self, tmp_path):
|
||||
name = "a" * 64
|
||||
skill_dir = tmp_path / name
|
||||
skill_dir.mkdir(parents=True)
|
||||
content = f"---\nname: {name}\ndescription: Exactly 64.\nlicense: MIT\n---\n\n## Body\n"
|
||||
path = skill_dir / "SKILL.md"
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
result = validate_strict(path)
|
||||
# May have other warnings but should not error on length
|
||||
assert not any("64" in e or "characters" in e.lower() for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck8NameDirectoryMatch:
|
||||
def test_error_on_name_dir_mismatch(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: other-skill
|
||||
description: Wrong name.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
# Directory is my-skill but name is other-skill
|
||||
path = _write_skill(tmp_path, content, dir_name="my-skill")
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("other-skill" in e or "my-skill" in e for e in result.errors)
|
||||
|
||||
def test_exact_match_passes(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Exact match.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content, dir_name="my-skill")
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
|
||||
def test_dot_namespace_prefix_passes(self, tmp_path):
|
||||
"""hive.my-skill with dir my-skill is valid (namespace prefix)."""
|
||||
content = """\
|
||||
---
|
||||
name: org.my-skill
|
||||
description: Namespaced.
|
||||
license: MIT
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content, dir_name="my-skill")
|
||||
result = validate_strict(path)
|
||||
# Should not error on name/dir mismatch for namespace prefix
|
||||
assert not any("my-skill" in e and "other" in e for e in result.errors)
|
||||
# Check no dir mismatch error specifically
|
||||
name_mismatch_errors = [e for e in result.errors if "my-skill" in e and "org.my-skill" in e]
|
||||
assert len(name_mismatch_errors) == 0
|
||||
|
||||
|
||||
class TestCheck9BodyNotEmpty:
|
||||
def test_error_on_empty_body(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: No body.
|
||||
license: MIT
|
||||
---
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("body" in e.lower() or "instructions" in e.lower() for e in result.errors)
|
||||
|
||||
|
||||
class TestCheck11Scripts:
|
||||
def test_error_on_non_executable_script(self, tmp_path):
|
||||
path = _write_skill(tmp_path, _VALID_CONTENT)
|
||||
scripts_dir = path.parent / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
script = scripts_dir / "run.sh"
|
||||
script.write_text("#!/bin/sh\necho hi")
|
||||
# Ensure NOT executable
|
||||
script.chmod(0o644)
|
||||
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("executable" in e.lower() for e in result.errors)
|
||||
|
||||
def test_passes_with_executable_script(self, tmp_path):
|
||||
path = _write_skill(tmp_path, _VALID_CONTENT)
|
||||
scripts_dir = path.parent / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
script = scripts_dir / "run.sh"
|
||||
script.write_text("#!/bin/sh\necho hi")
|
||||
script.chmod(0o755)
|
||||
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
|
||||
|
||||
class TestCheck12AllowedTools:
|
||||
def test_warning_on_malformed_allowed_tools(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Skill with bad tools.
|
||||
license: MIT
|
||||
allowed-tools: "not a list"
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert any("allowed-tools" in w.lower() for w in result.warnings)
|
||||
|
||||
def test_valid_allowed_tools_no_warning(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Valid tools list.
|
||||
license: MIT
|
||||
allowed-tools:
|
||||
- web_search
|
||||
- file_read
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert not any("allowed-tools" in w.lower() for w in result.warnings)
|
||||
|
||||
|
||||
class TestCheck13Compatibility:
|
||||
def test_error_on_non_list_compatibility(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Bad compat.
|
||||
license: MIT
|
||||
compatibility: "claude-code"
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("compatibility" in e.lower() for e in result.errors)
|
||||
|
||||
def test_valid_compatibility_passes(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Good compat.
|
||||
license: MIT
|
||||
compatibility:
|
||||
- claude-code
|
||||
- hive
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
|
||||
|
||||
class TestCheck14Metadata:
|
||||
def test_error_on_non_dict_metadata(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Bad metadata.
|
||||
license: MIT
|
||||
metadata: "not a dict"
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is False
|
||||
assert any("metadata" in e.lower() for e in result.errors)
|
||||
|
||||
def test_valid_metadata_passes(self, tmp_path):
|
||||
content = """\
|
||||
---
|
||||
name: my-skill
|
||||
description: Good metadata.
|
||||
license: MIT
|
||||
metadata:
|
||||
tags:
|
||||
- research
|
||||
---
|
||||
|
||||
## Body
|
||||
"""
|
||||
path = _write_skill(tmp_path, content)
|
||||
result = validate_strict(path)
|
||||
assert result.passed is True
|
||||
+3
-496
@@ -1,18 +1,16 @@
|
||||
"""Tests for the storage module - FileStorage and ConcurrentStorage backends.
|
||||
"""Tests for the storage module - ConcurrentStorage backend.
|
||||
|
||||
DEPRECATED: FileStorage and ConcurrentStorage are deprecated.
|
||||
DEPRECATED: FileStorage has been removed.
|
||||
New sessions use unified storage at sessions/{session_id}/state.json.
|
||||
These tests are kept for backward compatibility verification only.
|
||||
These tests are kept for backward compatibility verification of ConcurrentStorage only.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.schemas.run import Run, RunMetrics, RunStatus
|
||||
from framework.storage.backend import FileStorage
|
||||
from framework.storage.concurrent import CacheEntry, ConcurrentStorage
|
||||
|
||||
# === HELPER FUNCTIONS ===
|
||||
@@ -40,277 +38,6 @@ def create_test_run(
|
||||
)
|
||||
|
||||
|
||||
# === FILESTORAGE TESTS ===
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
|
||||
class TestFileStorageBasics:
|
||||
"""Test basic FileStorage operations."""
|
||||
|
||||
def test_init_creates_directories(self, tmp_path: Path):
|
||||
"""FileStorage should create the directory structure on init."""
|
||||
FileStorage(tmp_path)
|
||||
|
||||
assert (tmp_path / "runs").exists()
|
||||
assert (tmp_path / "summaries").exists()
|
||||
assert (tmp_path / "indexes" / "by_goal").exists()
|
||||
assert (tmp_path / "indexes" / "by_status").exists()
|
||||
assert (tmp_path / "indexes" / "by_node").exists()
|
||||
|
||||
def test_init_with_string_path(self, tmp_path: Path):
|
||||
"""FileStorage should accept string paths."""
|
||||
storage = FileStorage(str(tmp_path))
|
||||
assert storage.base_path == tmp_path
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
|
||||
class TestFileStorageRunOperations:
|
||||
"""Test FileStorage run CRUD operations."""
|
||||
|
||||
def test_save_and_load_run(self, tmp_path: Path):
|
||||
"""Test saving and loading a run."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run(run)
|
||||
loaded = storage.load_run(run.id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == run.id
|
||||
assert loaded.goal_id == run.goal_id
|
||||
assert loaded.status == run.status
|
||||
|
||||
def test_load_nonexistent_run_returns_none(self, tmp_path: Path):
|
||||
"""Loading a nonexistent run should return None."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
result = storage.load_run("nonexistent_id")
|
||||
assert result is None
|
||||
|
||||
def test_save_creates_json_file(self, tmp_path: Path):
|
||||
"""Saving a run should create a JSON file."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run(run_id="my_run")
|
||||
|
||||
storage.save_run(run)
|
||||
|
||||
run_file = tmp_path / "runs" / "my_run.json"
|
||||
assert run_file.exists()
|
||||
|
||||
# Verify it's valid JSON
|
||||
with open(run_file, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
assert data["id"] == "my_run"
|
||||
|
||||
def test_save_creates_summary(self, tmp_path: Path):
|
||||
"""Saving a run should also create a summary file."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run(run_id="my_run")
|
||||
|
||||
storage.save_run(run)
|
||||
|
||||
summary_file = tmp_path / "summaries" / "my_run.json"
|
||||
assert summary_file.exists()
|
||||
|
||||
def test_load_summary(self, tmp_path: Path):
|
||||
"""Test loading a run summary."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run(run)
|
||||
summary = storage.load_summary(run.id)
|
||||
|
||||
assert summary is not None
|
||||
assert summary.run_id == run.id
|
||||
assert summary.goal_id == run.goal_id
|
||||
assert summary.status == run.status
|
||||
|
||||
def test_load_summary_fallback_to_run(self, tmp_path: Path):
|
||||
"""If summary file is missing, load_summary should compute from run."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run(run)
|
||||
|
||||
# Delete the summary file
|
||||
summary_file = tmp_path / "summaries" / f"{run.id}.json"
|
||||
summary_file.unlink()
|
||||
|
||||
# Should still work by computing from run
|
||||
summary = storage.load_summary(run.id)
|
||||
assert summary is not None
|
||||
assert summary.run_id == run.id
|
||||
|
||||
def test_delete_run(self, tmp_path: Path):
|
||||
"""Test deleting a run."""
|
||||
storage = FileStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run(run)
|
||||
assert storage.load_run(run.id) is not None
|
||||
|
||||
result = storage.delete_run(run.id)
|
||||
|
||||
assert result is True
|
||||
assert storage.load_run(run.id) is None
|
||||
|
||||
def test_delete_nonexistent_run_returns_false(self, tmp_path: Path):
|
||||
"""Deleting a nonexistent run should return False."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
result = storage.delete_run("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
|
||||
class TestFileStorageIndexing:
|
||||
"""Test FileStorage index operations."""
|
||||
|
||||
def test_index_by_goal(self, tmp_path: Path):
|
||||
"""Runs should be indexed by goal_id."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
run1 = create_test_run(run_id="run_1", goal_id="goal_a")
|
||||
run2 = create_test_run(run_id="run_2", goal_id="goal_a")
|
||||
run3 = create_test_run(run_id="run_3", goal_id="goal_b")
|
||||
|
||||
storage.save_run(run1)
|
||||
storage.save_run(run2)
|
||||
storage.save_run(run3)
|
||||
|
||||
goal_a_runs = storage.get_runs_by_goal("goal_a")
|
||||
goal_b_runs = storage.get_runs_by_goal("goal_b")
|
||||
|
||||
assert len(goal_a_runs) == 2
|
||||
assert "run_1" in goal_a_runs
|
||||
assert "run_2" in goal_a_runs
|
||||
assert len(goal_b_runs) == 1
|
||||
assert "run_3" in goal_b_runs
|
||||
|
||||
def test_index_by_status(self, tmp_path: Path):
|
||||
"""Runs should be indexed by status."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
run1 = create_test_run(run_id="run_1", status=RunStatus.COMPLETED)
|
||||
run2 = create_test_run(run_id="run_2", status=RunStatus.FAILED)
|
||||
run3 = create_test_run(run_id="run_3", status=RunStatus.COMPLETED)
|
||||
|
||||
storage.save_run(run1)
|
||||
storage.save_run(run2)
|
||||
storage.save_run(run3)
|
||||
|
||||
completed = storage.get_runs_by_status(RunStatus.COMPLETED)
|
||||
failed = storage.get_runs_by_status(RunStatus.FAILED)
|
||||
|
||||
assert len(completed) == 2
|
||||
assert len(failed) == 1
|
||||
|
||||
def test_index_by_status_string(self, tmp_path: Path):
|
||||
"""get_runs_by_status should accept string status."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
run = create_test_run(status=RunStatus.RUNNING)
|
||||
storage.save_run(run)
|
||||
|
||||
runs = storage.get_runs_by_status("running")
|
||||
assert len(runs) == 1
|
||||
|
||||
def test_index_by_node(self, tmp_path: Path):
|
||||
"""Runs should be indexed by executed nodes."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
run1 = create_test_run(run_id="run_1", nodes_executed=["node_a", "node_b"])
|
||||
run2 = create_test_run(run_id="run_2", nodes_executed=["node_a", "node_c"])
|
||||
|
||||
storage.save_run(run1)
|
||||
storage.save_run(run2)
|
||||
|
||||
node_a_runs = storage.get_runs_by_node("node_a")
|
||||
node_b_runs = storage.get_runs_by_node("node_b")
|
||||
node_c_runs = storage.get_runs_by_node("node_c")
|
||||
|
||||
assert len(node_a_runs) == 2
|
||||
assert len(node_b_runs) == 1
|
||||
assert len(node_c_runs) == 1
|
||||
|
||||
def test_delete_removes_from_indexes(self, tmp_path: Path):
|
||||
"""Deleting a run should remove it from all indexes."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
run = create_test_run(
|
||||
run_id="run_1",
|
||||
goal_id="goal_a",
|
||||
status=RunStatus.COMPLETED,
|
||||
nodes_executed=["node_1"],
|
||||
)
|
||||
storage.save_run(run)
|
||||
|
||||
# Verify indexed
|
||||
assert "run_1" in storage.get_runs_by_goal("goal_a")
|
||||
assert "run_1" in storage.get_runs_by_status(RunStatus.COMPLETED)
|
||||
assert "run_1" in storage.get_runs_by_node("node_1")
|
||||
|
||||
# Delete
|
||||
storage.delete_run("run_1")
|
||||
|
||||
# Verify removed from indexes
|
||||
assert "run_1" not in storage.get_runs_by_goal("goal_a")
|
||||
assert "run_1" not in storage.get_runs_by_status(RunStatus.COMPLETED)
|
||||
assert "run_1" not in storage.get_runs_by_node("node_1")
|
||||
|
||||
def test_empty_index_returns_empty_list(self, tmp_path: Path):
|
||||
"""Querying an empty index should return empty list."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
assert storage.get_runs_by_goal("nonexistent") == []
|
||||
assert storage.get_runs_by_status("nonexistent") == []
|
||||
assert storage.get_runs_by_node("nonexistent") == []
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FileStorage is deprecated - use unified session storage")
|
||||
class TestFileStorageListOperations:
|
||||
"""Test FileStorage list operations."""
|
||||
|
||||
def test_list_all_runs(self, tmp_path: Path):
|
||||
"""Test listing all run IDs."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
storage.save_run(create_test_run(run_id="run_1"))
|
||||
storage.save_run(create_test_run(run_id="run_2"))
|
||||
storage.save_run(create_test_run(run_id="run_3"))
|
||||
|
||||
all_runs = storage.list_all_runs()
|
||||
|
||||
assert len(all_runs) == 3
|
||||
assert set(all_runs) == {"run_1", "run_2", "run_3"}
|
||||
|
||||
def test_list_all_goals(self, tmp_path: Path):
|
||||
"""Test listing all goal IDs that have runs."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
storage.save_run(create_test_run(run_id="run_1", goal_id="goal_a"))
|
||||
storage.save_run(create_test_run(run_id="run_2", goal_id="goal_b"))
|
||||
storage.save_run(create_test_run(run_id="run_3", goal_id="goal_a"))
|
||||
|
||||
all_goals = storage.list_all_goals()
|
||||
|
||||
assert len(all_goals) == 2
|
||||
assert set(all_goals) == {"goal_a", "goal_b"}
|
||||
|
||||
def test_get_stats(self, tmp_path: Path):
|
||||
"""Test getting storage statistics."""
|
||||
storage = FileStorage(tmp_path)
|
||||
|
||||
storage.save_run(create_test_run(run_id="run_1", goal_id="goal_a"))
|
||||
storage.save_run(create_test_run(run_id="run_2", goal_id="goal_b"))
|
||||
|
||||
stats = storage.get_stats()
|
||||
|
||||
assert stats["total_runs"] == 2
|
||||
assert stats["total_goals"] == 2
|
||||
assert stats["storage_path"] == str(tmp_path)
|
||||
|
||||
|
||||
# === CACHE ENTRY TESTS ===
|
||||
|
||||
|
||||
@@ -332,7 +59,6 @@ class TestCacheEntry:
|
||||
# === CONCURRENTSTORAGE TESTS ===
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageBasics:
|
||||
"""Test basic ConcurrentStorage operations."""
|
||||
|
||||
@@ -377,168 +103,6 @@ class TestConcurrentStorageBasics:
|
||||
assert storage._running is False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageRunOperations:
|
||||
"""Test ConcurrentStorage run operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_load_run(self, tmp_path: Path):
|
||||
"""Test async save and load of a run."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run()
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
loaded = await storage.load_run(run.id)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.id == run.id
|
||||
assert loaded.goal_id == run.goal_id
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_run_uses_cache(self, tmp_path: Path):
|
||||
"""Second load should use cached value."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run()
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# First load
|
||||
loaded1 = await storage.load_run(run.id)
|
||||
# Second load (should use cache)
|
||||
loaded2 = await storage.load_run(run.id, use_cache=True)
|
||||
|
||||
assert loaded1 is not None
|
||||
assert loaded2 is not None
|
||||
# Cache should return same object
|
||||
assert loaded1 is loaded2
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_run_bypass_cache(self, tmp_path: Path):
|
||||
"""Load with use_cache=False should bypass cache."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run()
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
loaded1 = await storage.load_run(run.id)
|
||||
loaded2 = await storage.load_run(run.id, use_cache=False)
|
||||
|
||||
assert loaded1 is not None
|
||||
assert loaded2 is not None
|
||||
# Fresh load should be different object
|
||||
assert loaded1 is not loaded2
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_run(self, tmp_path: Path):
|
||||
"""Test async delete of a run."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run()
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
result = await storage.delete_run(run.id)
|
||||
|
||||
assert result is True
|
||||
loaded = await storage.load_run(run.id)
|
||||
assert loaded is None
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_clears_cache(self, tmp_path: Path):
|
||||
"""Deleting a run should clear it from cache."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run()
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
# Load to populate cache
|
||||
await storage.load_run(run.id)
|
||||
assert f"run:{run.id}" in storage._cache
|
||||
|
||||
# Delete
|
||||
await storage.delete_run(run.id)
|
||||
|
||||
# Cache should be cleared
|
||||
assert f"run:{run.id}" not in storage._cache
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageQueryOperations:
|
||||
"""Test ConcurrentStorage query operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runs_by_goal(self, tmp_path: Path):
|
||||
"""Test async query by goal."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run1 = create_test_run(run_id="run_1", goal_id="goal_a")
|
||||
run2 = create_test_run(run_id="run_2", goal_id="goal_a")
|
||||
|
||||
await storage.save_run(run1, immediate=True)
|
||||
await storage.save_run(run2, immediate=True)
|
||||
|
||||
runs = await storage.get_runs_by_goal("goal_a")
|
||||
|
||||
assert len(runs) == 2
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runs_by_status(self, tmp_path: Path):
|
||||
"""Test async query by status."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
run = create_test_run(status=RunStatus.FAILED)
|
||||
await storage.save_run(run, immediate=True)
|
||||
|
||||
runs = await storage.get_runs_by_status(RunStatus.FAILED)
|
||||
|
||||
assert len(runs) == 1
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_runs(self, tmp_path: Path):
|
||||
"""Test async list all runs."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
await storage.save_run(create_test_run(run_id="run_1"), immediate=True)
|
||||
await storage.save_run(create_test_run(run_id="run_2"), immediate=True)
|
||||
|
||||
runs = await storage.list_all_runs()
|
||||
|
||||
assert len(runs) == 2
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageCacheManagement:
|
||||
"""Test ConcurrentStorage cache management."""
|
||||
|
||||
@@ -576,60 +140,3 @@ class TestConcurrentStorageCacheManagement:
|
||||
assert stats["total_entries"] == 2
|
||||
assert stats["expired_entries"] == 1
|
||||
assert stats["valid_entries"] == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageSyncAPI:
|
||||
"""Test ConcurrentStorage synchronous API for backward compatibility."""
|
||||
|
||||
def test_save_run_sync(self, tmp_path: Path):
|
||||
"""Test synchronous save."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run_sync(run)
|
||||
|
||||
# Verify saved
|
||||
loaded = storage.load_run_sync(run.id)
|
||||
assert loaded is not None
|
||||
assert loaded.id == run.id
|
||||
|
||||
def test_load_run_sync(self, tmp_path: Path):
|
||||
"""Test synchronous load."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
run = create_test_run()
|
||||
|
||||
storage.save_run_sync(run)
|
||||
loaded = storage.load_run_sync(run.id)
|
||||
|
||||
assert loaded is not None
|
||||
|
||||
def test_load_run_sync_nonexistent(self, tmp_path: Path):
|
||||
"""Synchronous load of nonexistent run returns None."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
|
||||
loaded = storage.load_run_sync("nonexistent")
|
||||
assert loaded is None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="ConcurrentStorage is deprecated - wraps deprecated FileStorage")
|
||||
class TestConcurrentStorageStats:
|
||||
"""Test ConcurrentStorage statistics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(self, tmp_path: Path):
|
||||
"""Test getting async storage stats."""
|
||||
storage = ConcurrentStorage(tmp_path)
|
||||
await storage.start()
|
||||
|
||||
try:
|
||||
await storage.save_run(create_test_run(), immediate=True)
|
||||
|
||||
stats = await storage.get_stats()
|
||||
|
||||
assert stats["total_runs"] == 1
|
||||
assert "cache" in stats
|
||||
assert "pending_writes" in stats
|
||||
assert stats["running"] is True
|
||||
finally:
|
||||
await storage.stop()
|
||||
|
||||
@@ -214,7 +214,7 @@ def test_load_registry_servers_retries_when_registration_returns_zero(monkeypatc
|
||||
registry = ToolRegistry()
|
||||
attempts = {"count": 0}
|
||||
|
||||
def fake_register(server_config, use_connection_manager=True):
|
||||
def fake_register(server_config, use_connection_manager=True, **kwargs):
|
||||
attempts["count"] += 1
|
||||
return 0 if attempts["count"] == 1 else 2
|
||||
|
||||
|
||||
+53
-40
@@ -198,33 +198,44 @@ Use the coder-tools MCP tools from your IDE agent chat (e.g., initialize_and_bui
|
||||
|
||||
If you prefer to build agents manually:
|
||||
|
||||
```python
|
||||
# exports/my_agent/agent.json
|
||||
```jsonc
|
||||
// exports/my_agent/agent.json
|
||||
{
|
||||
"goal": {
|
||||
"agent": {
|
||||
"id": "my_agent",
|
||||
"name": "Support Ticket Handler",
|
||||
"version": "1.0.0",
|
||||
"description": "Process customer support tickets"
|
||||
},
|
||||
"graph": {
|
||||
"id": "my_agent-graph",
|
||||
"goal_id": "support_ticket",
|
||||
"entry_node": "analyze",
|
||||
"terminal_nodes": ["analyze"],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "analyze",
|
||||
"name": "Analyze Ticket",
|
||||
"description": "Categorize and prioritize the support ticket",
|
||||
"node_type": "event_loop",
|
||||
"system_prompt": "Analyze this support ticket...",
|
||||
"input_keys": ["ticket_content"],
|
||||
"output_keys": ["category", "priority"]
|
||||
}
|
||||
],
|
||||
"edges": []
|
||||
},
|
||||
"goal": {
|
||||
"id": "support_ticket",
|
||||
"name": "Support Ticket Handler",
|
||||
"description": "Process customer support tickets",
|
||||
"success_criteria": "Ticket is categorized, prioritized, and routed correctly"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "analyze",
|
||||
"name": "Analyze Ticket",
|
||||
"node_type": "event_loop",
|
||||
"system_prompt": "Analyze this support ticket...",
|
||||
"input_keys": ["ticket_content"],
|
||||
"output_keys": ["category", "priority"]
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"edge_id": "start_to_analyze",
|
||||
"source": "START",
|
||||
"target": "analyze",
|
||||
"condition": "on_success"
|
||||
}
|
||||
]
|
||||
"success_criteria": [
|
||||
{
|
||||
"id": "sc-categorized",
|
||||
"description": "Ticket is categorized and prioritized correctly"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -532,16 +543,17 @@ def my_custom_tool(param1: str, param2: int) -> Dict[str, Any]:
|
||||
# Implementation
|
||||
return {"result": "success", "data": ...}
|
||||
|
||||
# Register tool in agent.json
|
||||
# Register tool in agent.json (inside "graph" → "nodes")
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "use_tool",
|
||||
"node_type": "event_loop",
|
||||
"tools": ["my_custom_tool"],
|
||||
...
|
||||
}
|
||||
]
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "use_tool",
|
||||
"node_type": "event_loop",
|
||||
"tools": ["my_custom_tool"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -560,15 +572,16 @@ def my_custom_tool(param1: str, param2: int) -> Dict[str, Any]:
|
||||
}
|
||||
}
|
||||
|
||||
# 2. Reference tools in agent.json
|
||||
# 2. Reference tools in agent.json (inside "graph" → "nodes")
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": "search",
|
||||
"tools": ["web_search", "web_scrape"],
|
||||
...
|
||||
}
|
||||
]
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "search",
|
||||
"tools": ["web_search", "web_scrape"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@
|
||||
"nullable_output_keys": [],
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"system_prompt": "You are a career analyst helping a job seeker find their best opportunities.\n\n**STEP 1 \u2014 Greet and collect resume (text only, NO tool calls):**\n\nAsk the user to paste their resume. Be friendly and concise:\n\"Please paste your resume below. I'll analyze your experience and identify the roles where you have the strongest chance of success.\"\n\n**STEP 2 \u2014 After the user provides their resume:**\n\nAnalyze the resume thoroughly:\n1. Identify key skills (technical and soft skills)\n2. Summarize years and types of experience\n3. Identify 3-5 SPECIFIC, GRANULAR role types where they're competitive\n\n**IMPORTANT \u2014 Role Specificity:**\nRespect the job seeker by providing granular options, not generic buckets.\n- BAD: \"Software Engineer\" (too broad)\n- GOOD: \"Backend Engineer (Python/Django)\", \"Platform Engineer\", \"API Developer\", \"Data Pipeline Engineer\"\n\nEach role should be distinct and searchable. The more specific, the better the job matches will be\n\nPresent your analysis to the user and ask if they agree with the role types identified. DO NOT ask follow-up questions. DO NOT ask which roles to focus on.\n\n**STEP 3 \u2014 After user confirms roles, call set_output:**\n\nUse set_output to store:\n- set_output(\"resume_text\", \"<the full resume text>\")\n- set_output(\"role_analysis\", \"<JSON with: skills, experience_summary, target_roles (3-5 specific role titles)>\")\n\nIMPORTANT: When the user says \"yes\", \"sure\", \"go ahead\", \"find jobs\" or similar, call set_output IMMEDIATELY. NEVER ask the user to pick between roles.",
|
||||
"tools": [],
|
||||
"system_prompt": "You are a career analyst helping a job seeker find their best opportunities.\n\n**STEP 1 \u2014 Greet and collect resume:**\n\nAsk the user to provide their resume. They can either paste the text directly or provide a path to a PDF file. Be friendly and concise:\n\"Please paste your resume below, or provide the file path to your PDF resume (e.g., /path/to/resume.pdf). I'll analyze your experience and identify the roles where you have the strongest chance of success.\"\n\nIf the user provides a file path to a PDF, call pdf_read(file_path=\"<path>\") to extract the text before proceeding.\n\n**STEP 2 \u2014 After the user provides their resume:**\n\nAnalyze the resume thoroughly:\n1. Identify key skills (technical and soft skills)\n2. Summarize years and types of experience\n3. Identify 3-5 SPECIFIC, GRANULAR role types where they're competitive\n\n**IMPORTANT \u2014 Role Specificity:**\nRespect the job seeker by providing granular options, not generic buckets.\n- BAD: \"Software Engineer\" (too broad)\n- GOOD: \"Backend Engineer (Python/Django)\", \"Platform Engineer\", \"API Developer\", \"Data Pipeline Engineer\"\n\nEach role should be distinct and searchable. The more specific, the better the job matches will be\n\nPresent your analysis to the user and ask if they agree with the role types identified. DO NOT ask follow-up questions. DO NOT ask which roles to focus on.\n\n**STEP 3 \u2014 After user confirms roles, call set_output:**\n\nUse set_output to store:\n- set_output(\"resume_text\", \"<the full resume text>\")\n- set_output(\"role_analysis\", \"<JSON with: skills, experience_summary, target_roles (3-5 specific role titles)>\")\n\nIMPORTANT: When the user says \"yes\", \"sure\", \"go ahead\", \"find jobs\" or similar, call set_output IMMEDIATELY. NEVER ask the user to pick between roles.",
|
||||
"tools": ["pdf_read"],
|
||||
"model": null,
|
||||
"function": null,
|
||||
"routes": {},
|
||||
@@ -261,7 +261,8 @@
|
||||
"append_data",
|
||||
"serve_file_to_user",
|
||||
"web_scrape",
|
||||
"gmail_create_draft"
|
||||
"gmail_create_draft",
|
||||
"pdf_read"
|
||||
],
|
||||
"metadata": {
|
||||
"created_at": "2026-02-13T18:41:10.324531",
|
||||
|
||||
@@ -9,9 +9,9 @@ intake_node = NodeSpec(
|
||||
name="Intake",
|
||||
description="Analyze resume and identify 3-5 strongest role types",
|
||||
node_type="event_loop",
|
||||
client_facing=False,
|
||||
client_facing=True,
|
||||
max_node_visits=1,
|
||||
input_keys=["resume_text"],
|
||||
input_keys=[],
|
||||
output_keys=["resume_text", "role_analysis"],
|
||||
success_criteria=(
|
||||
"The user's resume has been analyzed and 3-5 target roles identified "
|
||||
@@ -20,6 +20,12 @@ intake_node = NodeSpec(
|
||||
system_prompt="""\
|
||||
You are a career analyst. Your task is to analyze the user's resume and identify the best role fits.
|
||||
|
||||
**ACCEPTING THE RESUME:**
|
||||
The user can provide their resume in two ways:
|
||||
1. **Paste text** — The user pastes their resume content directly.
|
||||
2. **PDF file path** — The user provides a path to a PDF file (e.g., "/path/to/resume.pdf"). \
|
||||
If a file path is provided, call pdf_read(file_path="<path>") to extract the text before analyzing.
|
||||
|
||||
**PROCESS:**
|
||||
1. Identify key skills (technical and soft skills).
|
||||
2. Summarize years and types of experience.
|
||||
@@ -32,7 +38,7 @@ You MUST call set_output to store:
|
||||
|
||||
Do NOT wait for user confirmation. Simply perform the analysis and set the outputs.
|
||||
""",
|
||||
tools=[],
|
||||
tools=["pdf_read"],
|
||||
)
|
||||
|
||||
# Node 2: Job Search (simple)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{ "include": ["hive-tools"] }
|
||||
+165
-13
@@ -1022,6 +1022,10 @@ $hiveKey = [System.Environment]::GetEnvironmentVariable("HIVE_API_KEY", "User")
|
||||
if (-not $hiveKey) { $hiveKey = $env:HIVE_API_KEY }
|
||||
if ($hiveKey) { $HiveCredDetected = $true }
|
||||
|
||||
$AntigravityCredDetected = $false
|
||||
$antigravityAuthPath = Join-Path $env:USERPROFILE ".hive\antigravity-accounts.json"
|
||||
if (Test-Path $antigravityAuthPath) { $AntigravityCredDetected = $true }
|
||||
|
||||
# Detect API key providers
|
||||
$ProviderMenuEnvVars = @("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", "GROQ_API_KEY", "CEREBRAS_API_KEY", "OPENROUTER_API_KEY")
|
||||
$ProviderMenuNames = @("Anthropic (Claude) - Recommended", "OpenAI (GPT)", "Google Gemini - Free tier available", "Groq - Fast, free tier", "Cerebras - Fast, free tier", "OpenRouter - Bring any OpenRouter model")
|
||||
@@ -1035,6 +1039,12 @@ $ProviderMenuUrls = @(
|
||||
"https://openrouter.ai/keys"
|
||||
)
|
||||
|
||||
$OllamaDetected = $false
|
||||
try {
|
||||
$null = & ollama list 2>$null
|
||||
if ($LASTEXITCODE -eq 0) { $OllamaDetected = $true }
|
||||
} catch { }
|
||||
|
||||
# ── Read previous configuration (if any) ──────────────────────
|
||||
$PrevProvider = ""
|
||||
$PrevModel = ""
|
||||
@@ -1051,6 +1061,7 @@ if (Test-Path $HiveConfigFile) {
|
||||
if ($prevLlm.use_claude_code_subscription) { $PrevSubMode = "claude_code" }
|
||||
elseif ($prevLlm.use_codex_subscription) { $PrevSubMode = "codex" }
|
||||
elseif ($prevLlm.use_kimi_code_subscription) { $PrevSubMode = "kimi_code" }
|
||||
elseif ($prevLlm.use_antigravity_subscription) { $PrevSubMode = "antigravity" }
|
||||
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.z.ai*") { $PrevSubMode = "zai_code" }
|
||||
elseif ($prevLlm.provider -eq "minimax" -or ($prevLlm.api_base -and $prevLlm.api_base -like "*api.minimax.io*")) { $PrevSubMode = "minimax_code" }
|
||||
elseif ($prevLlm.api_base -and $prevLlm.api_base -like "*api.kimi.com*") { $PrevSubMode = "kimi_code" }
|
||||
@@ -1070,8 +1081,11 @@ if ($PrevSubMode -or $PrevProvider) {
|
||||
"minimax_code" { if ($MinimaxCredDetected) { $prevCredValid = $true } }
|
||||
"kimi_code" { if ($KimiCredDetected) { $prevCredValid = $true } }
|
||||
"hive_llm" { if ($HiveCredDetected) { $prevCredValid = $true } }
|
||||
"antigravity" { if ($AntigravityCredDetected) { $prevCredValid = $true } }
|
||||
default {
|
||||
if ($PrevEnvVar) {
|
||||
if ($PrevProvider -eq "ollama") {
|
||||
$prevCredValid = $true
|
||||
} elseif ($PrevEnvVar) {
|
||||
$envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "Process")
|
||||
if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($PrevEnvVar, "User") }
|
||||
if ($envVal) { $prevCredValid = $true }
|
||||
@@ -1086,17 +1100,20 @@ if ($PrevSubMode -or $PrevProvider) {
|
||||
"minimax_code" { $DefaultChoice = "4" }
|
||||
"kimi_code" { $DefaultChoice = "5" }
|
||||
"hive_llm" { $DefaultChoice = "6" }
|
||||
"antigravity" { $DefaultChoice = "7" }
|
||||
}
|
||||
if (-not $DefaultChoice) {
|
||||
switch ($PrevProvider) {
|
||||
"anthropic" { $DefaultChoice = "7" }
|
||||
"openai" { $DefaultChoice = "8" }
|
||||
"gemini" { $DefaultChoice = "9" }
|
||||
"groq" { $DefaultChoice = "10" }
|
||||
"cerebras" { $DefaultChoice = "11" }
|
||||
"openrouter" { $DefaultChoice = "12" }
|
||||
"anthropic" { $DefaultChoice = "8" }
|
||||
"openai" { $DefaultChoice = "9" }
|
||||
"gemini" { $DefaultChoice = "10" }
|
||||
"groq" { $DefaultChoice = "11" }
|
||||
"cerebras" { $DefaultChoice = "12" }
|
||||
"openrouter" { $DefaultChoice = "13" }
|
||||
"ollama" { $DefaultChoice = "14" }
|
||||
"minimax" { $DefaultChoice = "4" }
|
||||
"kimi" { $DefaultChoice = "5" }
|
||||
"hive" { $DefaultChoice = "6" }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1149,12 +1166,19 @@ Write-Host ") Hive LLM " -NoNewline
|
||||
Write-Color -Text "(use your Hive API key)" -Color DarkGray -NoNewline
|
||||
if ($HiveCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
|
||||
# 7) Antigravity
|
||||
Write-Host " " -NoNewline
|
||||
Write-Color -Text "7" -Color Cyan -NoNewline
|
||||
Write-Host ") Antigravity Subscription " -NoNewline
|
||||
Write-Color -Text "(use your Google/Gemini plan)" -Color DarkGray -NoNewline
|
||||
if ($AntigravityCredDetected) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
|
||||
Write-Host ""
|
||||
Write-Color -Text " API key providers:" -Color Cyan
|
||||
|
||||
# 7-12) API key providers
|
||||
# 8-13) API key providers
|
||||
for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
|
||||
$num = $idx + 7
|
||||
$num = $idx + 8
|
||||
$envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "Process")
|
||||
if (-not $envVal) { $envVal = [System.Environment]::GetEnvironmentVariable($ProviderMenuEnvVars[$idx], "User") }
|
||||
Write-Host " " -NoNewline
|
||||
@@ -1163,7 +1187,17 @@ for ($idx = 0; $idx -lt $ProviderMenuEnvVars.Count; $idx++) {
|
||||
if ($envVal) { Write-Color -Text " (credential detected)" -Color Green } else { Write-Host "" }
|
||||
}
|
||||
|
||||
$SkipChoice = 7 + $ProviderMenuEnvVars.Count
|
||||
# 14) Local (Ollama) - no API key needed
|
||||
Write-Host " " -NoNewline
|
||||
Write-Color -Text "14" -Color Cyan -NoNewline
|
||||
if ($OllamaDetected) {
|
||||
Write-Host ") Local (Ollama) - No API key needed " -NoNewline
|
||||
Write-Color -Text "(ollama detected)" -Color Green
|
||||
} else {
|
||||
Write-Host ") Local (Ollama) - No API key needed"
|
||||
}
|
||||
|
||||
$SkipChoice = 8 + $ProviderMenuEnvVars.Count + 1
|
||||
Write-Host " " -NoNewline
|
||||
Write-Color -Text "$SkipChoice" -Color Cyan -NoNewline
|
||||
Write-Host ") Skip for now"
|
||||
@@ -1301,9 +1335,48 @@ switch ($num) {
|
||||
}
|
||||
Write-Color -Text " Model: $SelectedModel | API: $HiveLlmEndpoint" -Color DarkGray
|
||||
}
|
||||
{ $_ -ge 7 -and $_ -le 12 } {
|
||||
7 {
|
||||
# Antigravity Subscription
|
||||
if (-not $AntigravityCredDetected) {
|
||||
Write-Host ""
|
||||
Write-Color -Text " Setting up Antigravity authentication..." -Color Cyan
|
||||
Write-Host ""
|
||||
Write-Warn "A browser window will open for Google OAuth."
|
||||
Write-Host " Sign in with your Google account that has Antigravity access."
|
||||
Write-Host ""
|
||||
try {
|
||||
$null = & $UvCmd run python (Join-Path $ScriptDir "core\antigravity_auth.py") auth account add 2>&1
|
||||
if ($LASTEXITCODE -eq 0 -and (Test-Path $antigravityAuthPath)) {
|
||||
$AntigravityCredDetected = $true
|
||||
}
|
||||
} catch {
|
||||
$AntigravityCredDetected = $false
|
||||
}
|
||||
|
||||
if (-not $AntigravityCredDetected) {
|
||||
Write-Host ""
|
||||
Write-Fail "Authentication failed or was cancelled."
|
||||
Write-Host ""
|
||||
$SelectedProviderId = ""
|
||||
}
|
||||
}
|
||||
|
||||
if ($AntigravityCredDetected) {
|
||||
$SubscriptionMode = "antigravity"
|
||||
$SelectedProviderId = "openai"
|
||||
$SelectedModel = "gemini-3-flash"
|
||||
$SelectedMaxTokens = 32768
|
||||
$SelectedMaxContextTokens = 1000000
|
||||
Write-Host ""
|
||||
Write-Warn "Using Antigravity can technically cause your account suspension. Please use at your own risk."
|
||||
Write-Host ""
|
||||
Write-Ok "Using Antigravity subscription"
|
||||
Write-Color -Text " Model: gemini-3-flash | Direct OAuth (no proxy required)" -Color DarkGray
|
||||
}
|
||||
}
|
||||
{ $_ -ge 8 -and $_ -le 13 } {
|
||||
# API key providers
|
||||
$provIdx = $num - 7
|
||||
$provIdx = $num - 8
|
||||
$SelectedEnvVar = $ProviderMenuEnvVars[$provIdx]
|
||||
$SelectedProviderId = $ProviderMenuIds[$provIdx]
|
||||
$providerName = $ProviderMenuNames[$provIdx] -replace ' - .*', '' # strip description
|
||||
@@ -1383,6 +1456,75 @@ switch ($num) {
|
||||
}
|
||||
}
|
||||
}
|
||||
14 {
|
||||
# Local (Ollama)
|
||||
if (-not $OllamaDetected) {
|
||||
Write-Host ""
|
||||
Write-Warn "Ollama depends on a local Ollama server, but 'ollama list' failed."
|
||||
Write-Host " Please install Ollama (https://ollama.com) and start the server,"
|
||||
Write-Host " then run this quickstart again."
|
||||
Write-Host ""
|
||||
exit 1
|
||||
}
|
||||
$SelectedProviderId = "ollama"
|
||||
Write-Host ""
|
||||
Write-Ok "Using Local (Ollama)"
|
||||
Write-Host ""
|
||||
|
||||
# Fetch available models
|
||||
$ollamaModels = @()
|
||||
try {
|
||||
$listOutput = & ollama list 2>$null
|
||||
if ($listOutput.Count -gt 1) {
|
||||
for ($i = 1; $i -lt $listOutput.Count; $i++) {
|
||||
$line = $listOutput[$i].Trim()
|
||||
if ($line) {
|
||||
$mName = ($line -split '\s+')[0]
|
||||
if ($mName) { $ollamaModels += $mName }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch { }
|
||||
|
||||
if ($ollamaModels.Count -eq 0) {
|
||||
Write-Warn "No Ollama models found."
|
||||
Write-Host " Please open another terminal, run 'ollama run <model>' (e.g. 'ollama run llama3'),"
|
||||
Write-Host " and then run this quickstart again."
|
||||
Write-Host ""
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Show model picker
|
||||
Write-Host " Select an Ollama model:"
|
||||
Write-Host ""
|
||||
$defaultIdx = "1"
|
||||
for ($i = 0; $i -lt $ollamaModels.Count; $i++) {
|
||||
Write-Color -Text " $($i + 1)" -Color Cyan -NoNewline
|
||||
Write-Host ") $($ollamaModels[$i])"
|
||||
if ($PrevProvider -eq "ollama" -and $PrevModel -eq $ollamaModels[$i]) {
|
||||
$defaultIdx = [string]($i + 1)
|
||||
}
|
||||
}
|
||||
Write-Host ""
|
||||
|
||||
while ($true) {
|
||||
$raw = Read-Host "Enter choice (1-$($ollamaModels.Count)) [$defaultIdx]"
|
||||
if ([string]::IsNullOrWhiteSpace($raw)) { $raw = $defaultIdx }
|
||||
if ($raw -match '^\d+$') {
|
||||
$num = [int]$raw
|
||||
if ($num -ge 1 -and $num -le $ollamaModels.Count) {
|
||||
$SelectedModel = $ollamaModels[$num - 1]
|
||||
Write-Host ""
|
||||
Write-Ok "Model: $SelectedModel"
|
||||
$SelectedMaxTokens = 8192
|
||||
$SelectedMaxContextTokens = 16384
|
||||
$SelectedApiBase = "http://localhost:11434"
|
||||
break
|
||||
}
|
||||
}
|
||||
Write-Color -Text "Invalid choice. Please enter 1-$($ollamaModels.Count)" -Color Red
|
||||
}
|
||||
}
|
||||
{ $_ -eq $SkipChoice } {
|
||||
Write-Host ""
|
||||
Write-Warn "Skipped. An LLM API key is required to test and use worker agents."
|
||||
@@ -1686,6 +1828,8 @@ if ($SelectedProviderId) {
|
||||
$config.llm["use_claude_code_subscription"] = $true
|
||||
} elseif ($SubscriptionMode -eq "codex") {
|
||||
$config.llm["use_codex_subscription"] = $true
|
||||
} elseif ($SubscriptionMode -eq "antigravity") {
|
||||
$config.llm["use_antigravity_subscription"] = $true
|
||||
} elseif ($SubscriptionMode -eq "zai_code") {
|
||||
$config.llm["api_base"] = "https://api.z.ai/api/coding/paas/v4"
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
@@ -1701,8 +1845,13 @@ if ($SelectedProviderId) {
|
||||
} elseif ($SelectedProviderId -eq "openrouter") {
|
||||
$config.llm["api_base"] = "https://openrouter.ai/api/v1"
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
} else {
|
||||
} elseif ($SelectedProviderId -eq "ollama") {
|
||||
$config.llm["api_base"] = "http://localhost:11434"
|
||||
$config.llm.Remove("api_key_env_var")
|
||||
} elseif ($SelectedEnvVar) {
|
||||
$config.llm["api_key_env_var"] = $SelectedEnvVar
|
||||
} else {
|
||||
$config.llm.Remove("api_key_env_var")
|
||||
}
|
||||
|
||||
$config | ConvertTo-Json -Depth 4 | Set-Content -Path $HiveConfigFile -Encoding UTF8
|
||||
@@ -2003,6 +2152,9 @@ if ($SelectedProviderId) {
|
||||
Write-Color -Text " API: api.minimax.io/v1 (OpenAI-compatible)" -Color DarkGray
|
||||
} elseif ($SubscriptionMode -eq "codex") {
|
||||
Write-Ok "OpenAI Codex Subscription -> $SelectedModel"
|
||||
} elseif ($SubscriptionMode -eq "antigravity") {
|
||||
Write-Ok "Antigravity Subscription -> $SelectedModel"
|
||||
Write-Color -Text " Direct OAuth (no proxy required)" -Color DarkGray
|
||||
} elseif ($SelectedProviderId -eq "openrouter") {
|
||||
Write-Ok "OpenRouter API Key -> $SelectedModel"
|
||||
Write-Color -Text " API: openrouter.ai/api/v1 (OpenAI-compatible)" -Color DarkGray
|
||||
|
||||
+97
-8
@@ -673,7 +673,18 @@ detect_shell_rc() {
|
||||
fi
|
||||
;;
|
||||
bash)
|
||||
if [ -f "$HOME/.bashrc" ]; then
|
||||
# Git Bash on Windows commonly starts as a login shell, so prefer
|
||||
# .bash_profile there when it already exists. On Unix-like shells,
|
||||
# keep the traditional .bashrc-first behavior.
|
||||
if [ -n "$MSYSTEM" ] || [ -n "$MINGW_PREFIX" ]; then
|
||||
if [ -f "$HOME/.bash_profile" ]; then
|
||||
echo "$HOME/.bash_profile"
|
||||
elif [ -f "$HOME/.bashrc" ]; then
|
||||
echo "$HOME/.bashrc"
|
||||
else
|
||||
echo "$HOME/.profile"
|
||||
fi
|
||||
elif [ -f "$HOME/.bashrc" ]; then
|
||||
echo "$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
echo "$HOME/.bash_profile"
|
||||
@@ -912,8 +923,9 @@ config["llm"] = {
|
||||
"model": model,
|
||||
"max_tokens": int(max_tokens),
|
||||
"max_context_tokens": int(max_context_tokens),
|
||||
"api_key_env_var": env_var,
|
||||
}
|
||||
if env_var:
|
||||
config["llm"]["api_key_env_var"] = env_var
|
||||
config["created_at"] = created_at
|
||||
|
||||
if use_claude_code_sub == "true":
|
||||
@@ -1024,6 +1036,11 @@ elif [ -f "$HOME/.hive/antigravity-accounts.json" ]; then
|
||||
ANTIGRAVITY_CRED_DETECTED=true
|
||||
fi
|
||||
|
||||
OLLAMA_DETECTED=false
|
||||
if ollama list >/dev/null 2>&1; then
|
||||
OLLAMA_DETECTED=true
|
||||
fi
|
||||
|
||||
# Detect API key providers
|
||||
if [ "$USE_ASSOC_ARRAYS" = true ]; then
|
||||
for env_var in "${!PROVIDER_NAMES[@]}"; do
|
||||
@@ -1056,9 +1073,12 @@ try:
|
||||
with open(cfg_path, encoding="utf-8-sig") as f:
|
||||
c = json.load(f)
|
||||
llm = c.get("llm", {})
|
||||
print(f"PREV_PROVIDER={llm.get(\"provider\", \"\")}")
|
||||
print(f"PREV_MODEL={llm.get(\"model\", \"\")}")
|
||||
print(f"PREV_ENV_VAR={llm.get(\"api_key_env_var\", \"\")}")
|
||||
prov = llm.get("provider", "")
|
||||
mod = llm.get("model", "")
|
||||
env = llm.get("api_key_env_var", "")
|
||||
print(f"PREV_PROVIDER='{prov}'")
|
||||
print(f"PREV_MODEL='{mod}'")
|
||||
print(f"PREV_ENV_VAR='{env}'")
|
||||
sub = ""
|
||||
if llm.get("use_claude_code_subscription"):
|
||||
sub = "claude_code"
|
||||
@@ -1093,8 +1113,12 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
|
||||
hive_llm) [ "$HIVE_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
antigravity) [ "$ANTIGRAVITY_CRED_DETECTED" = true ] && PREV_CRED_VALID=true ;;
|
||||
*)
|
||||
# API key provider — check if the env var is set
|
||||
if [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
|
||||
# API key provider — check if the env var is set; ollama uses local runtime detection
|
||||
if [ "$PREV_PROVIDER" = "ollama" ]; then
|
||||
if [ "$OLLAMA_DETECTED" = true ]; then
|
||||
PREV_CRED_VALID=true
|
||||
fi
|
||||
elif [ -n "$PREV_ENV_VAR" ] && [ -n "${!PREV_ENV_VAR}" ]; then
|
||||
PREV_CRED_VALID=true
|
||||
fi
|
||||
;;
|
||||
@@ -1118,6 +1142,7 @@ if [ -n "$PREV_SUB_MODE" ] || [ -n "$PREV_PROVIDER" ]; then
|
||||
groq) DEFAULT_CHOICE=11 ;;
|
||||
cerebras) DEFAULT_CHOICE=12 ;;
|
||||
openrouter) DEFAULT_CHOICE=13 ;;
|
||||
ollama) DEFAULT_CHOICE=14 ;;
|
||||
minimax) DEFAULT_CHOICE=4 ;;
|
||||
kimi) DEFAULT_CHOICE=5 ;;
|
||||
hive) DEFAULT_CHOICE=6 ;;
|
||||
@@ -1196,7 +1221,14 @@ for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
|
||||
fi
|
||||
done
|
||||
|
||||
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]}))
|
||||
# 14) Local (Ollama) — no API key needed
|
||||
if [ "$OLLAMA_DETECTED" = true ]; then
|
||||
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed ${GREEN}(ollama detected)${NC}"
|
||||
else
|
||||
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed"
|
||||
fi
|
||||
|
||||
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]} + 1))
|
||||
echo -e " ${CYAN}$SKIP_CHOICE)${NC} Skip for now"
|
||||
echo ""
|
||||
|
||||
@@ -1414,6 +1446,56 @@ case $choice in
|
||||
PROVIDER_NAME="OpenRouter"
|
||||
SIGNUP_URL="https://openrouter.ai/keys"
|
||||
;;
|
||||
14)
|
||||
# Local (Ollama) — no API key; pick model from ollama list
|
||||
if [ "$OLLAMA_DETECTED" != true ]; then
|
||||
echo ""
|
||||
echo -e "${YELLOW}Ollama depends on a local Ollama server, but 'ollama list' failed.${NC}"
|
||||
echo -e " Please install Ollama (https://ollama.com) and start the server,"
|
||||
echo -e " then run this quickstart again."
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
SELECTED_PROVIDER_ID="ollama"
|
||||
SELECTED_ENV_VAR=""
|
||||
SELECTED_MAX_TOKENS=8192
|
||||
SELECTED_MAX_CONTEXT_TOKENS=16384
|
||||
OLLAMA_MODELS=()
|
||||
while IFS= read -r line; do
|
||||
[ -n "$line" ] && OLLAMA_MODELS+=("$line")
|
||||
done < <(ollama list 2>/dev/null | tail -n +2 | awk '{print $1}')
|
||||
if [ ${#OLLAMA_MODELS[@]} -gt 0 ]; then
|
||||
echo ""
|
||||
echo -e "${BOLD}Select an Ollama model:${NC}"
|
||||
echo ""
|
||||
for idx in "${!OLLAMA_MODELS[@]}"; do
|
||||
num=$((idx + 1))
|
||||
echo -e " ${CYAN}$num)${NC} ${OLLAMA_MODELS[$idx]}"
|
||||
done
|
||||
echo ""
|
||||
while true; do
|
||||
read -r -p "Enter choice (1-${#OLLAMA_MODELS[@]}): " model_choice
|
||||
if [[ "$model_choice" =~ ^[0-9]+$ ]] && [ "$model_choice" -ge 1 ] && [ "$model_choice" -le ${#OLLAMA_MODELS[@]} ]; then
|
||||
SELECTED_MODEL="${OLLAMA_MODELS[$((model_choice - 1))]}"
|
||||
SELECTED_API_BASE="http://localhost:11434"
|
||||
break
|
||||
fi
|
||||
echo -e "${RED}Invalid choice. Please enter 1-${#OLLAMA_MODELS[@]}${NC}"
|
||||
done
|
||||
echo ""
|
||||
echo -e "${GREEN}⬢${NC} Using Ollama with model ${DIM}$SELECTED_MODEL${NC}"
|
||||
echo -e "${YELLOW} ⚠ Note: The framework uses a ~9,500 token system prompt and requires strong tool use.${NC}"
|
||||
echo -e "${YELLOW} For best results, use models like qwen2.5:72b+ or mistral-large.${NC}"
|
||||
echo ""
|
||||
else
|
||||
echo ""
|
||||
echo -e "${RED}No Ollama models found.${NC}"
|
||||
echo -e " Please open another terminal, run ${CYAN}ollama pull llama3${NC} (or another model),"
|
||||
echo -e " and then run this quickstart again."
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
"$SKIP_CHOICE")
|
||||
echo ""
|
||||
echo -e "${YELLOW}Skipped.${NC} An LLM API key is required to test and use worker agents."
|
||||
@@ -1584,6 +1666,10 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "$SELECTED_API_BASE" > /dev/null || SAVE_OK=false
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "ollama" ]; then
|
||||
# Pass api_base explicitly — LiteLLM requires this to route ollama/* models
|
||||
# to the local Ollama server instead of trying to reach a remote endpoint.
|
||||
save_configuration "ollama" "" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" "" "http://localhost:11434" > /dev/null || SAVE_OK=false
|
||||
else
|
||||
save_configuration "$SELECTED_PROVIDER_ID" "$SELECTED_ENV_VAR" "$SELECTED_MODEL" "$SELECTED_MAX_TOKENS" "$SELECTED_MAX_CONTEXT_TOKENS" > /dev/null || SAVE_OK=false
|
||||
fi
|
||||
@@ -1859,6 +1945,9 @@ if [ -n "$SELECTED_PROVIDER_ID" ]; then
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "openrouter" ]; then
|
||||
echo -e " ${GREEN}⬢${NC} OpenRouter API Key → ${DIM}$SELECTED_MODEL${NC}"
|
||||
echo -e " ${DIM}API: openrouter.ai/api/v1 (OpenAI-compatible)${NC}"
|
||||
elif [ "$SELECTED_PROVIDER_ID" = "ollama" ]; then
|
||||
echo -e " ${GREEN}⬢${NC} Local (Ollama) → ${DIM}$SELECTED_MODEL${NC}"
|
||||
echo -e " ${DIM}No API key required (runs locally via http://localhost:11434)${NC}"
|
||||
else
|
||||
echo -e " ${CYAN}$SELECTED_PROVIDER_ID${NC} → ${DIM}$SELECTED_MODEL${NC}"
|
||||
fi
|
||||
|
||||
@@ -318,7 +318,22 @@ PROVIDERS = {
|
||||
key, "https://api.cerebras.ai/v1/models", "Cerebras"
|
||||
),
|
||||
"openrouter": lambda key, **kw: check_openrouter(key, **kw),
|
||||
"minimax": lambda key, **kw: check_minimax(key),
|
||||
"deepseek": lambda key, **_: check_openai_compatible(
|
||||
key, "https://api.deepseek.com/v1/models", "DeepSeek"
|
||||
),
|
||||
"together": lambda key, **_: check_openai_compatible(
|
||||
key, "https://api.together.xyz/v1/models", "Together AI"
|
||||
),
|
||||
"mistral": lambda key, **_: check_openai_compatible(
|
||||
key, "https://api.mistral.ai/v1/models", "Mistral"
|
||||
),
|
||||
"xai": lambda key, **_: check_openai_compatible(
|
||||
key, "https://api.x.ai/v1/models", "xAI"
|
||||
),
|
||||
"perplexity": lambda key, **_: check_openai_compatible(
|
||||
key, "https://api.perplexity.ai/v1/models", "Perplexity"
|
||||
),
|
||||
"minimax": lambda key, **_: check_minimax(key),
|
||||
# Kimi For Coding uses an Anthropic-compatible endpoint; check via /v1/messages
|
||||
# with empty messages (same as check_anthropic, triggers 400 not 401).
|
||||
"kimi": lambda key, **kw: check_anthropic_compatible(
|
||||
|
||||
@@ -95,6 +95,7 @@ from .kafka import KAFKA_CREDENTIALS
|
||||
from .langfuse import LANGFUSE_CREDENTIALS
|
||||
from .linear import LINEAR_CREDENTIALS
|
||||
from .lusha import LUSHA_CREDENTIALS
|
||||
from .mattermost import MATTERMOST_CREDENTIALS
|
||||
from .microsoft_graph import MICROSOFT_GRAPH_CREDENTIALS
|
||||
from .mongodb import MONGODB_CREDENTIALS
|
||||
from .n8n import N8N_CREDENTIALS
|
||||
@@ -179,6 +180,7 @@ CREDENTIAL_SPECS = {
|
||||
**LANGFUSE_CREDENTIALS,
|
||||
**LINEAR_CREDENTIALS,
|
||||
**LUSHA_CREDENTIALS,
|
||||
**MATTERMOST_CREDENTIALS,
|
||||
**MICROSOFT_GRAPH_CREDENTIALS,
|
||||
**MONGODB_CREDENTIALS,
|
||||
**N8N_CREDENTIALS,
|
||||
@@ -271,6 +273,7 @@ __all__ = [
|
||||
"LANGFUSE_CREDENTIALS",
|
||||
"LINEAR_CREDENTIALS",
|
||||
"LUSHA_CREDENTIALS",
|
||||
"MATTERMOST_CREDENTIALS",
|
||||
"MICROSOFT_GRAPH_CREDENTIALS",
|
||||
"MONGODB_CREDENTIALS",
|
||||
"N8N_CREDENTIALS",
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Mattermost tool credentials.
|
||||
|
||||
Contains credentials for Mattermost server integration.
|
||||
"""
|
||||
|
||||
from .base import CredentialSpec
|
||||
|
||||
MATTERMOST_CREDENTIALS = {
|
||||
"mattermost": CredentialSpec(
|
||||
env_var="MATTERMOST_ACCESS_TOKEN",
|
||||
tools=[
|
||||
"mattermost_list_teams",
|
||||
"mattermost_list_channels",
|
||||
"mattermost_get_channel",
|
||||
"mattermost_send_message",
|
||||
"mattermost_get_posts",
|
||||
"mattermost_create_reaction",
|
||||
"mattermost_delete_post",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://developers.mattermost.com/integrate/reference/personal-access-token/",
|
||||
description="Mattermost Personal Access Token",
|
||||
aden_supported=False,
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To get a Mattermost Personal Access Token:
|
||||
1. Log in to your Mattermost server
|
||||
2. Go to Profile > Security > Personal Access Tokens
|
||||
3. Click "Create Token"
|
||||
4. Give it a description and click "Save"
|
||||
5. Copy the token (it won't be shown again)
|
||||
|
||||
Note: Personal access tokens must be enabled by your System Admin.
|
||||
Also set MATTERMOST_URL to your server URL (e.g. https://mattermost.example.com)""",
|
||||
health_check_endpoint=None,
|
||||
health_check_method="GET",
|
||||
credential_id="mattermost",
|
||||
credential_key="access_token",
|
||||
),
|
||||
"mattermost_url": CredentialSpec(
|
||||
env_var="MATTERMOST_URL",
|
||||
tools=[
|
||||
"mattermost_list_teams",
|
||||
"mattermost_list_channels",
|
||||
"mattermost_get_channel",
|
||||
"mattermost_send_message",
|
||||
"mattermost_get_posts",
|
||||
"mattermost_create_reaction",
|
||||
"mattermost_delete_post",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://developers.mattermost.com/integrate/reference/personal-access-token/",
|
||||
description="Mattermost Server URL (e.g. https://mattermost.example.com)",
|
||||
aden_supported=False,
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""Set this to your Mattermost server URL, e.g. https://mattermost.example.com
|
||||
Do not include /api/v4 — it will be added automatically.""",
|
||||
health_check_endpoint=None,
|
||||
health_check_method="GET",
|
||||
credential_id="mattermost_url",
|
||||
credential_key="url",
|
||||
),
|
||||
}
|
||||
@@ -1,13 +1,15 @@
|
||||
"""
|
||||
Shell configuration utilities for persisting environment variables.
|
||||
|
||||
Supports both bash and zsh, detecting the user's default shell.
|
||||
Supports bash and zsh with platform-aware fallbacks for login-shell config
|
||||
files such as ``.bash_profile``, ``.zshenv``, and ``.profile``.
|
||||
Used primarily for persisting ADEN_API_KEY across sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
@@ -34,9 +36,9 @@ def detect_shell() -> ShellType:
|
||||
else:
|
||||
# Try to detect from config file existence
|
||||
home = Path.home()
|
||||
if (home / ".zshrc").exists():
|
||||
if (home / ".zshrc").exists() or (home / ".zshenv").exists():
|
||||
return "zsh"
|
||||
elif (home / ".bashrc").exists():
|
||||
elif (home / ".bashrc").exists() or (home / ".bash_profile").exists():
|
||||
return "bash"
|
||||
return "unknown"
|
||||
|
||||
@@ -55,14 +57,12 @@ def get_shell_config_path(shell_type: ShellType | None = None) -> Path:
|
||||
shell_type = detect_shell()
|
||||
|
||||
home = Path.home()
|
||||
candidates = _get_shell_config_candidates(home, shell_type)
|
||||
|
||||
if shell_type == "zsh":
|
||||
return home / ".zshrc"
|
||||
elif shell_type == "bash":
|
||||
return home / ".bashrc"
|
||||
else:
|
||||
# Default to .bashrc for unknown shells
|
||||
return home / ".bashrc"
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def check_env_var_in_shell_config(
|
||||
@@ -79,29 +79,47 @@ def check_env_var_in_shell_config(
|
||||
Returns:
|
||||
Tuple of (exists, current_value or None)
|
||||
"""
|
||||
config_path = get_shell_config_path(shell_type)
|
||||
if shell_type is None:
|
||||
shell_type = detect_shell()
|
||||
|
||||
if not config_path.exists():
|
||||
return False, None
|
||||
for config_path in _get_shell_config_candidates(Path.home(), shell_type):
|
||||
if not config_path.exists():
|
||||
continue
|
||||
|
||||
content = config_path.read_text(encoding="utf-8")
|
||||
content = config_path.read_text(encoding="utf-8")
|
||||
|
||||
# Look for export ENV_VAR=value or export ENV_VAR="value"
|
||||
pattern = rf"^export\s+{re.escape(env_var)}=(.+)$"
|
||||
match = re.search(pattern, content, re.MULTILINE)
|
||||
# Look for export ENV_VAR=value or export ENV_VAR="value"
|
||||
pattern = rf"^export\s+{re.escape(env_var)}=(.+)$"
|
||||
match = re.search(pattern, content, re.MULTILINE)
|
||||
|
||||
if match:
|
||||
value = match.group(1).strip()
|
||||
# Remove surrounding quotes if present
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
value = value[1:-1]
|
||||
return True, value
|
||||
if match:
|
||||
value = match.group(1).strip()
|
||||
# Remove surrounding quotes if present
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
value = value[1:-1]
|
||||
return True, value
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
def _get_shell_config_candidates(home: Path, shell_type: ShellType) -> list[Path]:
|
||||
"""Return candidate config files in lookup order for the detected shell."""
|
||||
if shell_type == "zsh":
|
||||
return [home / ".zshrc", home / ".zshenv"]
|
||||
|
||||
if shell_type == "bash":
|
||||
# Git Bash commonly launches login shells on Windows, so prefer
|
||||
# ``.bash_profile`` there for writes, but keep ``.bashrc`` in the
|
||||
# lookup list so older setups continue to work.
|
||||
if platform.system() == "Windows":
|
||||
return [home / ".bash_profile", home / ".bashrc", home / ".profile"]
|
||||
return [home / ".bashrc", home / ".bash_profile", home / ".profile"]
|
||||
|
||||
return [home / ".profile", home / ".bashrc"]
|
||||
|
||||
|
||||
def add_env_var_to_shell_config(
|
||||
env_var: str,
|
||||
value: str,
|
||||
|
||||
@@ -88,6 +88,7 @@ from .kafka_tool import register_tools as register_kafka
|
||||
from .langfuse_tool import register_tools as register_langfuse
|
||||
from .linear_tool import register_tools as register_linear
|
||||
from .lusha_tool import register_tools as register_lusha
|
||||
from .mattermost_tool import register_tools as register_mattermost
|
||||
from .microsoft_graph_tool import register_tools as register_microsoft_graph
|
||||
from .mongodb_tool import register_tools as register_mongodb
|
||||
from .n8n_tool import register_tools as register_n8n
|
||||
@@ -266,6 +267,7 @@ def _register_unverified(
|
||||
register_langfuse(mcp, credentials=credentials)
|
||||
register_linear(mcp, credentials=credentials)
|
||||
register_lusha(mcp, credentials=credentials)
|
||||
register_mattermost(mcp, credentials=credentials)
|
||||
register_microsoft_graph(mcp, credentials=credentials)
|
||||
register_mongodb(mcp, credentials=credentials)
|
||||
register_n8n(mcp, credentials=credentials)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
@@ -330,39 +331,39 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
if not query or not query.strip():
|
||||
return {"error": "query cannot be empty"}
|
||||
|
||||
# Security: only allow SELECT statements
|
||||
query_upper = query.strip().upper()
|
||||
if not query_upper.startswith("SELECT"):
|
||||
# Security: allow SELECT/WITH only
|
||||
query_upper = query.lstrip().upper()
|
||||
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
|
||||
return {"error": "Only SELECT queries are allowed for security reasons"}
|
||||
|
||||
# Disallowed keywords for security
|
||||
disallowed = [
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"CREATE",
|
||||
"ALTER",
|
||||
"TRUNCATE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
]
|
||||
for keyword in disallowed:
|
||||
if keyword in query_upper:
|
||||
return {"error": f"'{keyword}' is not allowed in queries"}
|
||||
# Disallowed keywords for security (word-boundary match to avoid
|
||||
# false positives on column names like created_at, updated_at, etc.)
|
||||
_WRITE_PATTERN = re.compile(
|
||||
r"\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
match = _WRITE_PATTERN.search(query)
|
||||
if match:
|
||||
return {"error": f"'{match.group().upper()}' is not allowed in queries"}
|
||||
|
||||
# Block obvious multi-statement / injection attempts
|
||||
q_lower = query.lower()
|
||||
for token in [";", "--", "/*", "*/"]:
|
||||
if token in q_lower:
|
||||
return {"error": "Multiple statements or comments are not allowed"}
|
||||
|
||||
# Execute query using in-memory DuckDB
|
||||
con = duckdb.connect(":memory:")
|
||||
try:
|
||||
# Load CSV as 'data' table
|
||||
con.execute(f"CREATE TABLE data AS SELECT * FROM read_csv_auto('{secure_path}')")
|
||||
# SAFE: parameter binding (no string interpolation)
|
||||
con.execute(
|
||||
"CREATE TABLE data AS SELECT * FROM read_csv_auto(?)",
|
||||
[str(secure_path)],
|
||||
)
|
||||
|
||||
# Execute user query
|
||||
result = con.execute(query)
|
||||
columns = [desc[0] for desc in result.description]
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert to list of dicts
|
||||
rows_as_dicts = [dict(zip(columns, row, strict=False)) for row in rows]
|
||||
|
||||
return {
|
||||
@@ -374,12 +375,12 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
"rows": rows_as_dicts,
|
||||
"row_count": len(rows_as_dicts),
|
||||
}
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Make DuckDB errors more readable
|
||||
if "Catalog Error" in error_msg:
|
||||
return {"error": f"SQL error: {error_msg}. Remember the table is named 'data'."}
|
||||
return {"error": f"Query failed: {error_msg}"}
|
||||
|
||||
@@ -12,7 +12,6 @@ import os
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import httpx
|
||||
import resend
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,6 +34,15 @@ def register_tools(
|
||||
bcc: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Send email using Resend API."""
|
||||
try:
|
||||
import resend
|
||||
except ImportError:
|
||||
return {
|
||||
"error": (
|
||||
"resend not installed. Install with: "
|
||||
"pip install resend or pip install tools[email]"
|
||||
)
|
||||
}
|
||||
resend.api_key = api_key
|
||||
try:
|
||||
payload: dict = {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Mattermost Tool - Send messages and interact with Mattermost servers."""
|
||||
|
||||
from .mattermost_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Mattermost Tool - Send messages and interact with Mattermost servers via Mattermost API.
|
||||
|
||||
Supports:
|
||||
- Personal access tokens (MATTERMOST_ACCESS_TOKEN)
|
||||
- Self-hosted and cloud Mattermost instances (MATTERMOST_URL)
|
||||
|
||||
API Reference: https://api.mattermost.com/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
MAX_MESSAGE_LENGTH = 16383 # Mattermost API limit
|
||||
MAX_RETRIES = 2 # 3 total attempts on 429
|
||||
MAX_RETRY_WAIT = 60 # cap wait at 60s
|
||||
|
||||
|
||||
class _MattermostClient:
|
||||
"""Internal client wrapping Mattermost API calls."""
|
||||
|
||||
def __init__(self, access_token: str, base_url: str):
|
||||
# Strip trailing slash and ensure /api/v4 suffix
|
||||
base_url = base_url.rstrip("/")
|
||||
if not base_url.endswith("/api/v4"):
|
||||
base_url = f"{base_url}/api/v4"
|
||||
self._base_url = base_url
|
||||
self._token = access_token
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Make HTTP request with retry on 429 rate limit."""
|
||||
request_kwargs = {"headers": self._headers, "timeout": 30.0, **kwargs}
|
||||
for attempt in range(MAX_RETRIES + 1):
|
||||
response = httpx.request(method, url, **request_kwargs)
|
||||
if response.status_code == 429 and attempt < MAX_RETRIES:
|
||||
try:
|
||||
wait = min(float(response.headers.get("Retry-After", 1)), MAX_RETRY_WAIT)
|
||||
except (ValueError, TypeError):
|
||||
wait = min(2**attempt, MAX_RETRY_WAIT)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
return self._handle_response(response)
|
||||
return self._handle_response(response)
|
||||
|
||||
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
|
||||
"""Handle Mattermost API response format."""
|
||||
if response.status_code == 204:
|
||||
return {"success": True}
|
||||
|
||||
if response.status_code == 429:
|
||||
try:
|
||||
retry_after = float(response.headers.get("Retry-After", 60))
|
||||
except (ValueError, TypeError):
|
||||
retry_after = 60
|
||||
return {
|
||||
"error": f"Mattermost rate limit exceeded. Retry after {retry_after}s",
|
||||
"retry_after": retry_after,
|
||||
}
|
||||
|
||||
if response.status_code not in (200, 201):
|
||||
try:
|
||||
data = response.json()
|
||||
message = data.get("message", response.text)
|
||||
except Exception:
|
||||
message = response.text
|
||||
return {"error": f"HTTP {response.status_code}: {message}"}
|
||||
|
||||
return response.json()
|
||||
|
||||
def get_me(self) -> dict[str, Any]:
|
||||
"""Get the authenticated user's info (health check)."""
|
||||
return self._request_with_retry("GET", f"{self._base_url}/users/me")
|
||||
|
||||
def list_teams(self) -> dict[str, Any]:
|
||||
"""List teams the authenticated user belongs to."""
|
||||
return self._request_with_retry("GET", f"{self._base_url}/users/me/teams")
|
||||
|
||||
def list_channels(self, team_id: str, per_page: int = 100) -> dict[str, Any]:
|
||||
"""List public channels for a team."""
|
||||
return self._request_with_retry(
|
||||
"GET",
|
||||
f"{self._base_url}/teams/{team_id}/channels",
|
||||
params={"per_page": min(per_page, 200)},
|
||||
)
|
||||
|
||||
def get_channel(self, channel_id: str) -> dict[str, Any]:
|
||||
"""Get detailed information about a channel."""
|
||||
return self._request_with_retry("GET", f"{self._base_url}/channels/{channel_id}")
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
channel_id: str,
|
||||
message: str,
|
||||
*,
|
||||
root_id: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Create a post in a channel."""
|
||||
body: dict[str, Any] = {
|
||||
"channel_id": channel_id,
|
||||
"message": message,
|
||||
}
|
||||
if root_id:
|
||||
body["root_id"] = root_id
|
||||
return self._request_with_retry(
|
||||
"POST",
|
||||
f"{self._base_url}/posts",
|
||||
json=body,
|
||||
)
|
||||
|
||||
def get_posts(
|
||||
self,
|
||||
channel_id: str,
|
||||
per_page: int = 60,
|
||||
page: int = 0,
|
||||
before: str = "",
|
||||
after: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Get posts from a channel."""
|
||||
params: dict[str, Any] = {
|
||||
"per_page": min(per_page, 200),
|
||||
"page": page,
|
||||
}
|
||||
if before:
|
||||
params["before"] = before
|
||||
if after:
|
||||
params["after"] = after
|
||||
return self._request_with_retry(
|
||||
"GET",
|
||||
f"{self._base_url}/channels/{channel_id}/posts",
|
||||
params=params,
|
||||
)
|
||||
|
||||
def create_reaction(
|
||||
self,
|
||||
post_id: str,
|
||||
emoji_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a reaction to a post.
|
||||
|
||||
API ref: POST /reactions
|
||||
"""
|
||||
# Need user_id for the reaction; fetch from /users/me
|
||||
me = self.get_me()
|
||||
if isinstance(me, dict) and "error" in me:
|
||||
return me
|
||||
user_id = me.get("id", "")
|
||||
return self._request_with_retry(
|
||||
"POST",
|
||||
f"{self._base_url}/reactions",
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"post_id": post_id,
|
||||
"emoji_name": emoji_name,
|
||||
},
|
||||
)
|
||||
|
||||
def delete_post(self, post_id: str) -> dict[str, Any]:
|
||||
"""Delete a post."""
|
||||
return self._request_with_retry("DELETE", f"{self._base_url}/posts/{post_id}")
|
||||
|
||||
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register Mattermost tools with the MCP server."""
|
||||
|
||||
def _get_token(account: str = "") -> str | None:
|
||||
"""Get Mattermost access token from credential manager or environment."""
|
||||
if credentials is not None:
|
||||
if account:
|
||||
return credentials.get_by_alias("mattermost", account)
|
||||
token = credentials.get("mattermost")
|
||||
if token is not None and not isinstance(token, str):
|
||||
raise TypeError(
|
||||
"Expected string from credentials.get('mattermost'), "
|
||||
f"got {type(token).__name__}"
|
||||
)
|
||||
return token
|
||||
return os.getenv("MATTERMOST_ACCESS_TOKEN")
|
||||
|
||||
def _get_url() -> str | None:
|
||||
"""Get Mattermost server URL from credential manager or environment."""
|
||||
if credentials is not None:
|
||||
url = credentials.get("mattermost_url")
|
||||
if url is not None and not isinstance(url, str):
|
||||
raise TypeError(
|
||||
"Expected string from credentials.get('mattermost_url'), "
|
||||
f"got {type(url).__name__}"
|
||||
)
|
||||
if url:
|
||||
return url
|
||||
return os.getenv("MATTERMOST_URL")
|
||||
|
||||
def _get_client(account: str = "") -> _MattermostClient | dict[str, str]:
|
||||
"""Get a Mattermost client, or return an error dict if no credentials."""
|
||||
token = _get_token(account)
|
||||
if not token:
|
||||
return {
|
||||
"error": "Mattermost credentials not configured",
|
||||
"help": (
|
||||
"Set MATTERMOST_ACCESS_TOKEN and MATTERMOST_URL environment variables "
|
||||
"or configure via credential store"
|
||||
),
|
||||
}
|
||||
url = _get_url()
|
||||
if not url:
|
||||
return {
|
||||
"error": "Mattermost server URL not configured",
|
||||
"help": (
|
||||
"Set MATTERMOST_URL environment variable (e.g. https://mattermost.example.com) "
|
||||
"or configure via credential store"
|
||||
),
|
||||
}
|
||||
return _MattermostClient(token, url)
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_list_teams(account: str = "") -> dict:
|
||||
"""
|
||||
List Mattermost teams the authenticated user belongs to.
|
||||
|
||||
Returns team IDs and names. Use team IDs with mattermost_list_channels.
|
||||
|
||||
Returns:
|
||||
Dict with list of teams or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.list_teams()
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"teams": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_list_channels(team_id: str, per_page: int = 100, account: str = "") -> dict:
|
||||
"""
|
||||
List public channels for a Mattermost team.
|
||||
|
||||
Args:
|
||||
team_id: Team ID. Use mattermost_list_teams to find team IDs.
|
||||
per_page: Max channels to return (1-200, default 100).
|
||||
|
||||
Returns:
|
||||
Dict with list of channels or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.list_channels(team_id, per_page=per_page)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"channels": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_get_channel(channel_id: str, account: str = "") -> dict:
|
||||
"""
|
||||
Get detailed information about a Mattermost channel.
|
||||
|
||||
Returns channel metadata including name, display name, header, purpose,
|
||||
and type.
|
||||
|
||||
Args:
|
||||
channel_id: Channel ID
|
||||
|
||||
Returns:
|
||||
Dict with channel details or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.get_channel(channel_id)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"channel": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_send_message(
|
||||
channel_id: str,
|
||||
message: str,
|
||||
root_id: str = "",
|
||||
account: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Send a message (post) to a Mattermost channel.
|
||||
|
||||
Args:
|
||||
channel_id: Channel ID to post in
|
||||
message: Message text (max 16383 characters). Supports Markdown.
|
||||
root_id: Optional post ID to reply to (creates a thread)
|
||||
|
||||
Returns:
|
||||
Dict with post details or error
|
||||
"""
|
||||
if len(message) > MAX_MESSAGE_LENGTH:
|
||||
return {
|
||||
"error": f"Message exceeds {MAX_MESSAGE_LENGTH} character limit",
|
||||
"max_length": MAX_MESSAGE_LENGTH,
|
||||
"provided": len(message),
|
||||
}
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.send_message(channel_id, message, root_id=root_id)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"success": True, "post": result}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_get_posts(
|
||||
channel_id: str,
|
||||
per_page: int = 60,
|
||||
page: int = 0,
|
||||
before: str = "",
|
||||
after: str = "",
|
||||
account: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Get posts from a Mattermost channel.
|
||||
|
||||
Args:
|
||||
channel_id: Channel ID
|
||||
per_page: Max posts to return (1-200, default 60)
|
||||
page: Page number for pagination (default 0)
|
||||
before: Post ID to get posts before (for pagination)
|
||||
after: Post ID to get posts after (for pagination)
|
||||
|
||||
Returns:
|
||||
Dict with posts or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.get_posts(
|
||||
channel_id,
|
||||
per_page=per_page,
|
||||
page=page,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"posts": result, "success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_create_reaction(
|
||||
post_id: str,
|
||||
emoji_name: str,
|
||||
account: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Add a reaction to a Mattermost post.
|
||||
|
||||
Args:
|
||||
post_id: ID of the post to react to
|
||||
emoji_name: Emoji name without colons (e.g. "thumbsup", "heart")
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.create_reaction(post_id, emoji_name)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"success": True}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
@mcp.tool()
|
||||
def mattermost_delete_post(
|
||||
post_id: str,
|
||||
account: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a post from Mattermost.
|
||||
|
||||
Requires appropriate permissions (post author or admin).
|
||||
|
||||
Args:
|
||||
post_id: ID of the post to delete
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
"""
|
||||
client = _get_client(account)
|
||||
if isinstance(client, dict):
|
||||
return client
|
||||
try:
|
||||
result = client.delete_post(post_id)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return result
|
||||
return {"success": True, "deleted_post_id": post_id}
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
@@ -8,25 +8,13 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
def _get_config() -> tuple[str, dict] | dict:
|
||||
"""Return (base_url, headers) or error dict."""
|
||||
base_url = os.getenv("SAP_BASE_URL", "").rstrip("/")
|
||||
username = os.getenv("SAP_USERNAME", "")
|
||||
password = os.getenv("SAP_PASSWORD", "")
|
||||
if not base_url or not username or not password:
|
||||
return {
|
||||
"error": "SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD are required",
|
||||
"help": "Set SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD environment variables",
|
||||
}
|
||||
creds = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers = {"Authorization": f"Basic {creds}", "Accept": "application/json"}
|
||||
return base_url, headers
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
|
||||
def _get(url: str, headers: dict, params: dict | None = None) -> dict:
|
||||
@@ -45,9 +33,43 @@ def _odata_list(data: dict) -> tuple[list, int | None]:
|
||||
return results, count
|
||||
|
||||
|
||||
def register_tools(mcp: FastMCP, credentials: Any = None) -> None:
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register SAP S/4HANA tools."""
|
||||
|
||||
def _get_config() -> tuple[str, dict] | dict[str, str]:
|
||||
"""Return (base_url, headers) or error dict."""
|
||||
base_url = username = password = None
|
||||
if credentials is not None:
|
||||
try:
|
||||
base_url = credentials.get("sap_base_url")
|
||||
username = credentials.get("sap_username")
|
||||
password = credentials.get("sap_password")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
base_url = base_url or os.getenv("SAP_BASE_URL")
|
||||
username = username or os.getenv("SAP_USERNAME")
|
||||
password = password or os.getenv("SAP_PASSWORD")
|
||||
|
||||
if not base_url or not username or not password:
|
||||
return {
|
||||
"error": "SAP credentials not configured",
|
||||
"help": (
|
||||
"Set SAP_BASE_URL, SAP_USERNAME, and SAP_PASSWORD "
|
||||
"environment variables or configure via credential store"
|
||||
),
|
||||
}
|
||||
base_url = base_url.rstrip("/")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers = {
|
||||
"Authorization": f"Basic {encoded}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
return base_url, headers
|
||||
|
||||
@mcp.tool()
|
||||
def sap_list_purchase_orders(
|
||||
top: int = 50,
|
||||
|
||||
@@ -4,10 +4,13 @@ Web Scrape Tool - Extract content from web pages.
|
||||
Uses Playwright with stealth for headless browser scraping,
|
||||
enabling JavaScript-rendered content and bot detection evasion.
|
||||
Uses BeautifulSoup for HTML parsing and content extraction.
|
||||
Validates URLs against internal network ranges to prevent SSRF attacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from urllib.robotparser import RobotFileParser
|
||||
@@ -29,6 +32,49 @@ BROWSER_USER_AGENT = (
|
||||
)
|
||||
|
||||
|
||||
def _is_internal_address(raw_ip: str) -> bool:
|
||||
"""Check whether an IP address targets non-public infrastructure."""
|
||||
ip_str = raw_ip.split("%")[0] if "%" in raw_ip else raw_ip
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return True # Unparseable — fail closed
|
||||
return not addr.is_global or addr.is_multicast
|
||||
|
||||
|
||||
def _check_url_target(url: str) -> str | None:
|
||||
"""Resolve a URL's hostname and reject it if any address is non-public.
|
||||
|
||||
Returns an error message if blocked, None if safe.
|
||||
"""
|
||||
hostname = urlparse(url).hostname
|
||||
if not hostname:
|
||||
return "Invalid URL: missing hostname"
|
||||
|
||||
# Fast-path for raw IP literals
|
||||
try:
|
||||
ipaddress.ip_address(hostname)
|
||||
if _is_internal_address(hostname):
|
||||
return f"Blocked: direct request to internal address ({hostname})"
|
||||
except ValueError:
|
||||
pass # Not an IP literal, resolve below
|
||||
|
||||
try:
|
||||
results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
return f"DNS resolution failed for host: {hostname}"
|
||||
|
||||
if not results:
|
||||
return f"No DNS records found for host: {hostname}"
|
||||
|
||||
for entry in results:
|
||||
resolved_ip = str(entry[4][0])
|
||||
if _is_internal_address(resolved_ip):
|
||||
return f"Blocked: {hostname} resolves to internal address"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def register_tools(mcp: FastMCP) -> None:
|
||||
"""Register web scrape tools with the MCP server."""
|
||||
|
||||
@@ -65,6 +111,12 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
# Validate max_length
|
||||
max_length = max(1000, min(max_length, 500000))
|
||||
|
||||
# SSRF check: validate URL before making any request (must run
|
||||
# before robots.txt fetch, which also makes a network request)
|
||||
block_reason = _check_url_target(url)
|
||||
if block_reason is not None:
|
||||
return {"error": block_reason, "blocked_by_ssrf_protection": True, "url": url}
|
||||
|
||||
# Check robots.txt before launching browser
|
||||
if respect_robots_txt:
|
||||
try:
|
||||
@@ -102,12 +154,44 @@ def register_tools(mcp: FastMCP) -> None:
|
||||
page = await context.new_page()
|
||||
await Stealth().apply_stealth_async(page)
|
||||
|
||||
# Intercept navigation requests to block SSRF via redirects.
|
||||
# Only check "document" requests (navigations), not
|
||||
# sub-resources (CSS/JS/images) to avoid false positives
|
||||
# and unnecessary DNS lookups.
|
||||
ssrf_blocked: dict[str, Any] | None = None
|
||||
|
||||
async def _ssrf_route_handler(route):
|
||||
nonlocal ssrf_blocked
|
||||
req_url = route.request.url
|
||||
|
||||
# Skip non-network schemes (data:, blob:, etc.)
|
||||
if urlparse(req_url).scheme not in {"http", "https"}:
|
||||
await route.continue_()
|
||||
return
|
||||
|
||||
block = _check_url_target(req_url)
|
||||
if block is not None:
|
||||
ssrf_blocked = {
|
||||
"error": block,
|
||||
"blocked_by_ssrf_protection": True,
|
||||
"url": req_url,
|
||||
}
|
||||
await route.abort("blockedbyclient")
|
||||
else:
|
||||
await route.continue_()
|
||||
|
||||
await page.route("**/*", _ssrf_route_handler)
|
||||
|
||||
response = await page.goto(
|
||||
url,
|
||||
wait_until="domcontentloaded",
|
||||
timeout=60000,
|
||||
)
|
||||
|
||||
# Check if a redirect was blocked by SSRF protection
|
||||
if ssrf_blocked is not None:
|
||||
return ssrf_blocked
|
||||
|
||||
# Validate response before waiting for JS render
|
||||
if response is None:
|
||||
return {"error": "Navigation failed: no response received"}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Tests for shell config path selection and env-var lookups."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from aden_tools.credentials import shell_config
|
||||
|
||||
|
||||
def _mock_home(monkeypatch, tmp_path: Path) -> None:
|
||||
monkeypatch.setattr(shell_config.Path, "home", staticmethod(lambda: tmp_path))
|
||||
|
||||
|
||||
def test_get_shell_config_path_prefers_existing_bash_profile(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/usr/bin/bash")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
|
||||
|
||||
(tmp_path / ".bashrc").write_text("# bashrc\n", encoding="utf-8")
|
||||
(tmp_path / ".bash_profile").write_text("# bash profile\n", encoding="utf-8")
|
||||
|
||||
assert shell_config.get_shell_config_path() == tmp_path / ".bash_profile"
|
||||
|
||||
|
||||
def test_get_shell_config_path_prefers_bashrc_for_non_windows_bash(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/usr/bin/bash")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Linux")
|
||||
|
||||
(tmp_path / ".bashrc").write_text("# bashrc\n", encoding="utf-8")
|
||||
(tmp_path / ".bash_profile").write_text("# bash profile\n", encoding="utf-8")
|
||||
|
||||
assert shell_config.get_shell_config_path() == tmp_path / ".bashrc"
|
||||
|
||||
|
||||
def test_check_env_var_in_shell_config_reads_bash_profile(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/usr/bin/bash")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
|
||||
|
||||
(tmp_path / ".bash_profile").write_text(
|
||||
'export HIVE_API_KEY="hive-key-123"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
assert shell_config.check_env_var_in_shell_config("HIVE_API_KEY") == (
|
||||
True,
|
||||
"hive-key-123",
|
||||
)
|
||||
|
||||
|
||||
def test_check_env_var_in_shell_config_falls_back_to_bashrc_on_windows(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/usr/bin/bash")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Windows")
|
||||
|
||||
(tmp_path / ".bash_profile").write_text("# no key here\n", encoding="utf-8")
|
||||
(tmp_path / ".bashrc").write_text(
|
||||
'export HIVE_API_KEY="hive-key-from-bashrc"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
assert shell_config.check_env_var_in_shell_config("HIVE_API_KEY") == (
|
||||
True,
|
||||
"hive-key-from-bashrc",
|
||||
)
|
||||
|
||||
|
||||
def test_check_env_var_in_shell_config_reads_zshenv_when_zshrc_missing(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/bin/zsh")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Darwin")
|
||||
|
||||
(tmp_path / ".zshenv").write_text(
|
||||
"export OPENROUTER_API_KEY='or-key-123'\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
assert shell_config.check_env_var_in_shell_config("OPENROUTER_API_KEY") == (
|
||||
True,
|
||||
"or-key-123",
|
||||
)
|
||||
|
||||
|
||||
def test_get_shell_config_path_falls_back_to_profile_for_unknown_shell(monkeypatch, tmp_path):
|
||||
_mock_home(monkeypatch, tmp_path)
|
||||
monkeypatch.setenv("SHELL", "/usr/bin/fish")
|
||||
monkeypatch.setattr(shell_config.platform, "system", lambda: "Linux")
|
||||
|
||||
(tmp_path / ".profile").write_text("# profile\n", encoding="utf-8")
|
||||
|
||||
assert shell_config.get_shell_config_path() == tmp_path / ".profile"
|
||||
@@ -732,6 +732,100 @@ class TestCsvSql:
|
||||
assert "id" in result["columns"]
|
||||
assert "name" in result["columns"]
|
||||
|
||||
def test_path_with_single_quote(self, csv_tools, session_dir, tmp_path):
|
||||
"""Regression: CSV paths containing single quotes should work (parameter binding)."""
|
||||
csv_file = session_dir / "O'Reilly.csv"
|
||||
csv_file.write_text("name,age\nAlice,21\nBob,22\n", encoding="utf-8")
|
||||
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path="O'Reilly.csv",
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query="SELECT * FROM data",
|
||||
)
|
||||
|
||||
assert "error" not in result, result
|
||||
assert result["success"] is True
|
||||
assert result["row_count"] == 2
|
||||
names = [row["name"] for row in result["rows"]]
|
||||
assert "Alice" in names
|
||||
assert "Bob" in names
|
||||
|
||||
# --- NEW: security regression tests required by Issue #1256 ---
|
||||
|
||||
def test_reject_non_select(self, csv_tools, products_csv, tmp_path):
|
||||
"""Reject any non-SELECT / non-WITH query."""
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path=products_csv.name,
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query="DROP TABLE data",
|
||||
)
|
||||
assert "error" in result
|
||||
|
||||
def test_reject_multi_statement(self, csv_tools, products_csv, tmp_path):
|
||||
"""Reject multi-statement queries with semicolons."""
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path=products_csv.name,
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query="SELECT * FROM data; DROP TABLE data",
|
||||
)
|
||||
assert "error" in result
|
||||
|
||||
def test_reject_sql_comment_dash(self, csv_tools, products_csv, tmp_path):
|
||||
"""Reject queries with SQL line comments."""
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path=products_csv.name,
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query="SELECT * FROM data -- WHERE id = 1",
|
||||
)
|
||||
assert "error" in result
|
||||
|
||||
def test_with_cte_allowed(self, csv_tools, products_csv, tmp_path):
|
||||
"""Allow valid WITH (CTE) queries."""
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path=products_csv.name,
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query=(
|
||||
"WITH electronics AS (SELECT * FROM data"
|
||||
" WHERE category = 'Electronics')"
|
||||
" SELECT * FROM electronics"
|
||||
),
|
||||
)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_keyword_in_column_name_allowed(self, csv_tools, session_dir, tmp_path):
|
||||
"""Column names like created_at should not trigger keyword blocking."""
|
||||
csv_file = session_dir / "timestamps.csv"
|
||||
csv_file.write_text(
|
||||
"created_at,updated_at,value\n2024-01-01,2024-01-02,100\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
result = csv_tools["csv_sql"](
|
||||
path="timestamps.csv",
|
||||
workspace_id=TEST_WORKSPACE_ID,
|
||||
agent_id=TEST_AGENT_ID,
|
||||
session_id=TEST_SESSION_ID,
|
||||
query="SELECT created_at, updated_at FROM data",
|
||||
)
|
||||
assert "error" not in result, result
|
||||
assert result["success"] is True
|
||||
|
||||
def test_where_clause(self, csv_tools, products_csv, tmp_path):
|
||||
"""Filter with WHERE clause."""
|
||||
with patch("aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR", str(tmp_path)):
|
||||
|
||||
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
Tests for Mattermost tool.
|
||||
|
||||
Covers:
|
||||
- _MattermostClient methods (list_teams, list_channels, send_message, get_posts, etc.)
|
||||
- Error handling (401, 403, 404, 429, timeout)
|
||||
- Credential retrieval (CredentialStoreAdapter vs env var)
|
||||
- All MCP tool functions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from aden_tools.tools.mattermost_tool.mattermost_tool import (
|
||||
MAX_MESSAGE_LENGTH,
|
||||
MAX_RETRIES,
|
||||
_MattermostClient,
|
||||
register_tools,
|
||||
)
|
||||
|
||||
# --- _MattermostClient tests ---
|
||||
|
||||
|
||||
class TestMattermostClient:
|
||||
def setup_method(self):
|
||||
self.client = _MattermostClient("test-access-token", "https://mattermost.example.com")
|
||||
|
||||
def test_headers(self):
|
||||
headers = self.client._headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["Authorization"] == "Bearer test-access-token"
|
||||
|
||||
def test_base_url_strips_trailing_slash(self):
|
||||
client = _MattermostClient("tok", "https://mm.example.com/")
|
||||
assert client._base_url == "https://mm.example.com/api/v4"
|
||||
|
||||
def test_base_url_preserves_api_v4(self):
|
||||
client = _MattermostClient("tok", "https://mm.example.com/api/v4")
|
||||
assert client._base_url == "https://mm.example.com/api/v4"
|
||||
|
||||
def test_base_url_appends_api_v4(self):
|
||||
client = _MattermostClient("tok", "https://mm.example.com")
|
||||
assert client._base_url == "https://mm.example.com/api/v4"
|
||||
|
||||
def test_handle_response_success(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.json.return_value = {"id": "abc123", "username": "testbot"}
|
||||
assert self.client._handle_response(response) == {
|
||||
"id": "abc123",
|
||||
"username": "testbot",
|
||||
}
|
||||
|
||||
def test_handle_response_201(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 201
|
||||
response.json.return_value = {"id": "post123", "message": "hello"}
|
||||
result = self.client._handle_response(response)
|
||||
assert result == {"id": "post123", "message": "hello"}
|
||||
|
||||
def test_handle_response_204(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 204
|
||||
result = self.client._handle_response(response)
|
||||
assert result == {"success": True}
|
||||
|
||||
def test_handle_response_rate_limit_429(self):
|
||||
response = MagicMock()
|
||||
response.status_code = 429
|
||||
response.headers = {"Retry-After": "2.5"}
|
||||
result = self.client._handle_response(response)
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert result["retry_after"] == 2.5
|
||||
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 404, 500])
|
||||
def test_handle_response_errors(self, status_code):
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
response.json.return_value = {"message": "Test error"}
|
||||
response.text = "Test error"
|
||||
result = self.client._handle_response(response)
|
||||
assert "error" in result
|
||||
assert str(status_code) in result["error"]
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_list_teams(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "t1", "name": "test-team", "display_name": "Test Team"},
|
||||
{"id": "t2", "name": "dev-team", "display_name": "Dev Team"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.list_teams()
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[0][0] == "GET"
|
||||
assert "users/me/teams" in mock_request.call_args[0][1]
|
||||
assert len(result) == 2
|
||||
assert result[0]["display_name"] == "Test Team"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_list_channels(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "town-square", "type": "O"},
|
||||
{"id": "c2", "name": "off-topic", "type": "O"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self.client.list_channels("t1")
|
||||
mock_request.assert_called_once()
|
||||
assert "teams/t1/channels" in mock_request.call_args[0][1]
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "town-square"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_send_message(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "p123",
|
||||
"channel_id": "c1",
|
||||
"message": "Hello world",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self.client.send_message("c1", "Hello world")
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[0][0] == "POST"
|
||||
assert "posts" in mock_request.call_args[0][1]
|
||||
assert result["message"] == "Hello world"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_send_message_with_thread(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "p124",
|
||||
"channel_id": "c1",
|
||||
"message": "Reply",
|
||||
"root_id": "p123",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self.client.send_message("c1", "Reply", root_id="p123")
|
||||
call_kwargs = mock_request.call_args[1]
|
||||
assert call_kwargs["json"]["root_id"] == "p123"
|
||||
assert result["root_id"] == "p123"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_get_posts(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"order": ["p1", "p2"],
|
||||
"posts": {
|
||||
"p1": {"id": "p1", "message": "First"},
|
||||
"p2": {"id": "p2", "message": "Second"},
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self.client.get_posts("c1", per_page=10)
|
||||
mock_request.assert_called_once()
|
||||
assert mock_request.call_args[1]["params"]["per_page"] == 10
|
||||
assert "order" in result
|
||||
assert len(result["order"]) == 2
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_get_channel(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "c1",
|
||||
"name": "town-square",
|
||||
"display_name": "Town Square",
|
||||
"type": "O",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self.client.get_channel("c1")
|
||||
assert result["name"] == "town-square"
|
||||
assert result["type"] == "O"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_delete_post(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200, json=MagicMock(return_value={"status": "ok"})
|
||||
)
|
||||
self.client.delete_post("p123")
|
||||
assert mock_request.call_args[0][0] == "DELETE"
|
||||
assert "posts/p123" in mock_request.call_args[0][1]
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_create_reaction(self, mock_request):
|
||||
# First call returns user info, second creates the reaction
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value={"id": "user123", "username": "testbot"}),
|
||||
),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"user_id": "user123",
|
||||
"post_id": "p123",
|
||||
"emoji_name": "thumbsup",
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
result = self.client.create_reaction("p123", "thumbsup")
|
||||
assert result["emoji_name"] == "thumbsup"
|
||||
# Second call should be the reaction POST
|
||||
assert mock_request.call_args_list[1][1]["json"]["emoji_name"] == "thumbsup"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.time.sleep")
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_retry_on_429_then_success(self, mock_request, mock_sleep):
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=429,
|
||||
headers={"Retry-After": "0.01"},
|
||||
text="{}",
|
||||
),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value=[{"id": "t1", "name": "team"}]),
|
||||
),
|
||||
]
|
||||
result = self.client.list_teams()
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "team"
|
||||
assert mock_request.call_count == 2
|
||||
mock_sleep.assert_called_once_with(0.01)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.time.sleep")
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_retry_exhausted_returns_error(self, mock_request, mock_sleep):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=429,
|
||||
headers={"Retry-After": "0.01"},
|
||||
text="{}",
|
||||
)
|
||||
result = self.client.list_teams()
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert mock_request.call_count == MAX_RETRIES + 1
|
||||
|
||||
|
||||
# --- Tool registration tests ---
|
||||
|
||||
|
||||
class TestMattermostListTeamsTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_list_teams_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value=[{"id": "t1", "name": "test-team"}]),
|
||||
)
|
||||
result = self._fn("mattermost_list_teams")()
|
||||
assert result["success"] is True
|
||||
assert len(result["teams"]) == 1
|
||||
assert result["teams"][0]["name"] == "test-team"
|
||||
|
||||
def test_list_teams_no_credentials(self):
|
||||
mcp = MagicMock()
|
||||
fns = []
|
||||
mcp.tool.return_value = lambda fn: fns.append(fn) or fn
|
||||
register_tools(mcp, credentials=None)
|
||||
with patch.dict("os.environ", {"MATTERMOST_ACCESS_TOKEN": ""}, clear=False):
|
||||
result = next(f for f in fns if f.__name__ == "mattermost_list_teams")()
|
||||
assert "error" in result
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
|
||||
class TestMattermostListChannelsTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_list_channels_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value=[
|
||||
{"id": "c1", "name": "town-square", "type": "O"},
|
||||
]
|
||||
),
|
||||
)
|
||||
result = self._fn("mattermost_list_channels")("team-123")
|
||||
assert result["success"] is True
|
||||
assert len(result["channels"]) == 1
|
||||
assert result["channels"][0]["name"] == "town-square"
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_list_channels_error(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=404,
|
||||
json=MagicMock(return_value={"message": "Unknown Team"}),
|
||||
text="Unknown Team",
|
||||
)
|
||||
result = self._fn("mattermost_list_channels")("bad-team")
|
||||
assert "error" in result
|
||||
assert "404" in result["error"]
|
||||
|
||||
|
||||
class TestMattermostSendMessageTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_send_message_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "p123",
|
||||
"channel_id": "c1",
|
||||
"message": "Incident resolved",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self._fn("mattermost_send_message")("c1", "Incident resolved")
|
||||
assert result["success"] is True
|
||||
assert result["post"]["message"] == "Incident resolved"
|
||||
|
||||
def test_send_message_length_validation(self):
|
||||
long_content = "x" * (MAX_MESSAGE_LENGTH + 1)
|
||||
result = self._fn("mattermost_send_message")("c1", long_content)
|
||||
assert "error" in result
|
||||
assert str(MAX_MESSAGE_LENGTH) in result["error"]
|
||||
assert result["max_length"] == MAX_MESSAGE_LENGTH
|
||||
assert result["provided"] == MAX_MESSAGE_LENGTH + 1
|
||||
|
||||
def test_send_message_exactly_at_limit(self):
|
||||
content = "x" * MAX_MESSAGE_LENGTH
|
||||
with patch(
|
||||
"aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request"
|
||||
) as mock_request:
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"id": "p1", "channel_id": "c1", "message": content}),
|
||||
)
|
||||
result = self._fn("mattermost_send_message")("c1", content)
|
||||
assert result["success"] is True
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_send_message_rate_limit_429_exhausted(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=429,
|
||||
headers={"Retry-After": "5"},
|
||||
text="{}",
|
||||
)
|
||||
result = self._fn("mattermost_send_message")("c1", "Hello")
|
||||
assert "error" in result
|
||||
assert "rate limit" in result["error"].lower()
|
||||
assert mock_request.call_count == MAX_RETRIES + 1
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_send_message_rate_limit_then_success(self, mock_request):
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=429,
|
||||
headers={"Retry-After": "0.01"},
|
||||
text="{}",
|
||||
),
|
||||
MagicMock(
|
||||
status_code=201,
|
||||
json=MagicMock(return_value={"id": "p1", "channel_id": "c1", "message": "Hi"}),
|
||||
),
|
||||
]
|
||||
result = self._fn("mattermost_send_message")("c1", "Hi")
|
||||
assert result["success"] is True
|
||||
assert result["post"]["message"] == "Hi"
|
||||
assert mock_request.call_count == 2
|
||||
|
||||
|
||||
class TestMattermostGetPostsTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_get_posts_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"order": ["p1"],
|
||||
"posts": {"p1": {"id": "p1", "message": "First message"}},
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self._fn("mattermost_get_posts")("c1", per_page=10)
|
||||
assert result["success"] is True
|
||||
assert "posts" in result
|
||||
|
||||
|
||||
class TestMattermostGetChannelTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_get_channel_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"id": "c1",
|
||||
"name": "town-square",
|
||||
"display_name": "Town Square",
|
||||
"type": "O",
|
||||
}
|
||||
),
|
||||
)
|
||||
result = self._fn("mattermost_get_channel")("c1")
|
||||
assert result["success"] is True
|
||||
assert result["channel"]["name"] == "town-square"
|
||||
|
||||
|
||||
class TestMattermostDeletePostTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_delete_post_success(self, mock_request):
|
||||
mock_request.return_value = MagicMock(
|
||||
status_code=200, json=MagicMock(return_value={"status": "ok"})
|
||||
)
|
||||
result = self._fn("mattermost_delete_post")("p123")
|
||||
assert result["success"] is True
|
||||
assert result["deleted_post_id"] == "p123"
|
||||
|
||||
|
||||
class TestMattermostCreateReactionTool:
|
||||
def setup_method(self):
|
||||
self.mcp = MagicMock()
|
||||
self.fns = []
|
||||
self.mcp.tool.return_value = lambda fn: self.fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": "https://mattermost.example.com",
|
||||
}.get(key)
|
||||
register_tools(self.mcp, credentials=cred)
|
||||
|
||||
def _fn(self, name):
|
||||
return next(f for f in self.fns if f.__name__ == name)
|
||||
|
||||
@patch("aden_tools.tools.mattermost_tool.mattermost_tool.httpx.request")
|
||||
def test_create_reaction_success(self, mock_request):
|
||||
mock_request.side_effect = [
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(return_value={"id": "user123"}),
|
||||
),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=MagicMock(
|
||||
return_value={
|
||||
"user_id": "user123",
|
||||
"post_id": "p123",
|
||||
"emoji_name": "thumbsup",
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
result = self._fn("mattermost_create_reaction")("p123", "thumbsup")
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestMattermostNoUrl:
|
||||
"""Test that missing URL returns a helpful error."""
|
||||
|
||||
def test_missing_url_returns_error(self):
|
||||
mcp = MagicMock()
|
||||
fns = []
|
||||
mcp.tool.return_value = lambda fn: fns.append(fn) or fn
|
||||
cred = MagicMock()
|
||||
# Token is set but URL is not
|
||||
cred.get.side_effect = lambda key: {
|
||||
"mattermost": "test-token",
|
||||
"mattermost_url": None,
|
||||
}.get(key)
|
||||
register_tools(mcp, credentials=cred)
|
||||
with patch.dict("os.environ", {"MATTERMOST_URL": ""}, clear=False):
|
||||
fn = next(f for f in fns if f.__name__ == "mattermost_list_teams")
|
||||
result = fn()
|
||||
assert "error" in result
|
||||
assert "URL" in result["error"]
|
||||
|
||||
|
||||
# --- Credential spec tests ---
|
||||
|
||||
|
||||
class TestCredentialSpec:
|
||||
def test_mattermost_credential_spec_exists(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
assert "mattermost" in CREDENTIAL_SPECS
|
||||
|
||||
def test_mattermost_spec_env_var(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS["mattermost"]
|
||||
assert spec.env_var == "MATTERMOST_ACCESS_TOKEN"
|
||||
|
||||
def test_mattermost_spec_tools(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS["mattermost"]
|
||||
assert "mattermost_list_teams" in spec.tools
|
||||
assert "mattermost_list_channels" in spec.tools
|
||||
assert "mattermost_get_channel" in spec.tools
|
||||
assert "mattermost_send_message" in spec.tools
|
||||
assert "mattermost_get_posts" in spec.tools
|
||||
assert "mattermost_create_reaction" in spec.tools
|
||||
assert "mattermost_delete_post" in spec.tools
|
||||
assert len(spec.tools) == 7
|
||||
|
||||
def test_mattermost_url_credential_spec_exists(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
assert "mattermost_url" in CREDENTIAL_SPECS
|
||||
|
||||
def test_mattermost_url_spec_env_var(self):
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
spec = CREDENTIAL_SPECS["mattermost_url"]
|
||||
assert spec.env_var == "MATTERMOST_URL"
|
||||
@@ -22,13 +22,30 @@ def _mock_resp(data, status_code=200):
|
||||
return resp
|
||||
|
||||
|
||||
def _mock_credentials() -> MagicMock:
|
||||
creds = MagicMock()
|
||||
creds.get.side_effect = lambda key: {
|
||||
"sap_base_url": "https://cred-store.s4hana.ondemand.com",
|
||||
"sap_username": "CRED_USER",
|
||||
"sap_password": "cred-password",
|
||||
}.get(key)
|
||||
return creds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fns(mcp: FastMCP):
|
||||
def tool_fns(mcp: FastMCP) -> dict:
|
||||
register_tools(mcp, credentials=None)
|
||||
tools = mcp._tool_manager._tools
|
||||
return {name: tools[name].fn for name in tools}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fns_with_creds(mcp: FastMCP) -> dict:
|
||||
register_tools(mcp, credentials=_mock_credentials())
|
||||
tools = mcp._tool_manager._tools
|
||||
return {name: tools[name].fn for name in tools}
|
||||
|
||||
|
||||
class TestSAPListPurchaseOrders:
|
||||
def test_missing_credentials(self, tool_fns):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
@@ -176,3 +193,66 @@ class TestSAPListSalesOrders:
|
||||
assert result["count"] == 1
|
||||
assert result["sales_orders"][0]["sales_order"] == "1"
|
||||
assert result["sales_orders"][0]["net_amount"] == "25000.00"
|
||||
|
||||
|
||||
class TestCredentialStoreAdapter:
|
||||
"""Verify credentials are resolved via CredentialStoreAdapter."""
|
||||
|
||||
def test_credential_store_used(self, tool_fns_with_creds):
|
||||
data = {
|
||||
"d": {
|
||||
"__count": "1",
|
||||
"results": [
|
||||
{
|
||||
"PurchaseOrder": "4500000001",
|
||||
"PurchaseOrderType": "NB",
|
||||
"CompanyCode": "1010",
|
||||
"Supplier": "17300001",
|
||||
"CreationDate": "/Date(1672531200000)/",
|
||||
"PurchaseOrderNetAmount": "15000.00",
|
||||
"DocumentCurrency": "USD",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
with patch(
|
||||
"aden_tools.tools.sap_tool.sap_tool.httpx.get",
|
||||
return_value=_mock_resp(data),
|
||||
) as mock_get:
|
||||
result = tool_fns_with_creds["sap_list_purchase_orders"]()
|
||||
|
||||
assert result["count"] == 1
|
||||
call_url = mock_get.call_args.args[0]
|
||||
assert "cred-store.s4hana.ondemand.com" in call_url
|
||||
|
||||
def test_credential_store_missing_values(self):
|
||||
creds = MagicMock()
|
||||
creds.get.return_value = None
|
||||
|
||||
mcp = FastMCP("test")
|
||||
register_tools(mcp, credentials=creds)
|
||||
tools = mcp._tool_manager._tools
|
||||
fn = tools["sap_list_purchase_orders"].fn
|
||||
|
||||
result = fn()
|
||||
assert "error" in result
|
||||
|
||||
def test_env_fallback_when_no_adapter(self, tool_fns):
|
||||
data = {
|
||||
"d": {
|
||||
"__count": "0",
|
||||
"results": [],
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(
|
||||
"aden_tools.tools.sap_tool.sap_tool.httpx.get",
|
||||
return_value=_mock_resp(data),
|
||||
) as mock_get,
|
||||
):
|
||||
result = tool_fns["sap_list_purchase_orders"]()
|
||||
|
||||
assert result["count"] == 0
|
||||
call_url = mock_get.call_args.args[0]
|
||||
assert "my-tenant-api.s4hana.ondemand.com" in call_url
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
"""Tests for web_scrape tool (FastMCP)."""
|
||||
|
||||
import socket
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.web_scrape_tool import register_tools
|
||||
from aden_tools.tools.web_scrape_tool.web_scrape_tool import (
|
||||
_check_url_target,
|
||||
_is_internal_address,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -430,3 +435,100 @@ class TestWebScrapeToolRobotsTxt:
|
||||
result = await web_scrape_fn(url="https://example.com", respect_robots_txt=False)
|
||||
assert "error" not in result
|
||||
mock_rp_cls.assert_not_called()
|
||||
|
||||
|
||||
_MOD = "aden_tools.tools.web_scrape_tool.web_scrape_tool"
|
||||
|
||||
|
||||
class TestIsInternalAddress:
|
||||
"""Tests for _is_internal_address."""
|
||||
|
||||
def test_loopback_ipv4(self):
|
||||
assert _is_internal_address("127.0.0.1") is True
|
||||
|
||||
def test_private_10_range(self):
|
||||
assert _is_internal_address("10.0.0.1") is True
|
||||
|
||||
def test_private_192_168(self):
|
||||
assert _is_internal_address("192.168.1.1") is True
|
||||
|
||||
def test_link_local_aws_metadata(self):
|
||||
assert _is_internal_address("169.254.169.254") is True
|
||||
|
||||
def test_public_ipv4(self):
|
||||
assert _is_internal_address("8.8.8.8") is False
|
||||
|
||||
def test_public_ipv6(self):
|
||||
assert _is_internal_address("2607:f8b0:4004:800::200e") is False
|
||||
|
||||
def test_invalid_string_blocked(self):
|
||||
assert _is_internal_address("not-an-ip") is True
|
||||
|
||||
|
||||
def _fake_addrinfo(ip: str, port: int = 443) -> list[tuple]:
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))]
|
||||
|
||||
|
||||
class TestCheckUrlTarget:
|
||||
"""Tests for _check_url_target."""
|
||||
|
||||
@patch(f"{_MOD}.socket.getaddrinfo")
|
||||
def test_public_hostname_allowed(self, mock_dns):
|
||||
mock_dns.return_value = _fake_addrinfo("93.184.216.34")
|
||||
assert _check_url_target("https://example.com/page") is None
|
||||
|
||||
@patch(f"{_MOD}.socket.getaddrinfo")
|
||||
def test_private_hostname_blocked(self, mock_dns):
|
||||
mock_dns.return_value = _fake_addrinfo("10.0.0.1")
|
||||
result = _check_url_target("https://evil.com/steal")
|
||||
assert result is not None
|
||||
assert "internal" in result.lower()
|
||||
|
||||
def test_raw_private_ip_blocked(self):
|
||||
result = _check_url_target("http://127.0.0.1/admin")
|
||||
assert result is not None
|
||||
|
||||
@patch(
|
||||
f"{_MOD}.socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("NXDOMAIN"),
|
||||
)
|
||||
def test_dns_failure_returns_error(self, _mock_dns):
|
||||
result = _check_url_target("https://nonexistent.invalid/")
|
||||
assert result is not None
|
||||
assert "DNS" in result
|
||||
|
||||
|
||||
class TestWebScrapeSSRF:
|
||||
"""SSRF protection through the web_scrape tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_private_ip(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="http://192.168.1.1/admin")
|
||||
assert "error" in result
|
||||
assert result.get("blocked_by_ssrf_protection") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_localhost(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="http://127.0.0.1/secret")
|
||||
assert "error" in result
|
||||
assert result.get("blocked_by_ssrf_protection") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_metadata_endpoint(self, web_scrape_fn):
|
||||
result = await web_scrape_fn(url="http://169.254.169.254/latest/meta-data/")
|
||||
assert "error" in result
|
||||
assert result.get("blocked_by_ssrf_protection") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(_STEALTH_PATH)
|
||||
@patch(_PW_PATH)
|
||||
@patch(f"{_MOD}._check_url_target", return_value=None)
|
||||
async def test_allows_public_url(self, _mock_check, mock_pw, mock_stealth, web_scrape_fn):
|
||||
html = "<html><body><p>Hello world</p></body></html>"
|
||||
mock_cm, _, _ = _make_playwright_mocks(html)
|
||||
mock_pw.return_value = mock_cm
|
||||
mock_stealth.return_value.apply_stealth_async = AsyncMock()
|
||||
|
||||
result = await web_scrape_fn(url="https://example.com/")
|
||||
assert "error" not in result
|
||||
assert "Hello world" in result["content"]
|
||||
|
||||
Reference in New Issue
Block a user