Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8c4085f5e8 | |||
| 53240eb888 | |||
| de8d6f0946 | |||
| ea707438f2 | |||
| 445c9600ab | |||
| 2ab5e6d784 | |||
| e7f9b7d791 | |||
| 3cb0c69a96 | |||
| 22d75bfb05 | |||
| 357df1bbcb | |||
| 386bbd5780 | |||
| 235022b35d | |||
| 4d8f312c3e | |||
| 4651a6a85a | |||
| ea9c163438 | |||
| 77cc169606 | |||
| 8c6428f445 | |||
| 44cb0c0f4c | |||
| 2621fb88b1 | |||
| a70f92edbe | |||
| b2efa179ea | |||
| 8c6e76d052 | |||
| c7f1fbf19f | |||
| 7047ecbf46 | |||
| b96ee5aaab | |||
| 6744bea01a | |||
| 390038225b | |||
| b55c8fdf86 | |||
| e9aea0bbc4 | |||
| 0ba1fa8262 | |||
| 0fd96d410e | |||
| c658a7c50b | |||
| 56c3659bda | |||
| 14f927996c | |||
| 8a0ec070b8 | |||
| 80cd77ac30 |
@@ -85,7 +85,12 @@ from framework.agent_loop.internals.types import (
|
||||
JudgeVerdict,
|
||||
TriggerEvent,
|
||||
)
|
||||
from framework.agent_loop.internals.vision_fallback import (
|
||||
caption_tool_image,
|
||||
extract_intent_for_tool,
|
||||
)
|
||||
from framework.agent_loop.types import AgentContext, AgentProtocol, AgentResult
|
||||
from framework.config import get_vision_fallback_model
|
||||
from framework.host.event_bus import EventBus
|
||||
from framework.llm.capabilities import filter_tools_for_model, supports_image_tool_results
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
@@ -219,6 +224,46 @@ async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str |
|
||||
return None
|
||||
|
||||
|
||||
def _vision_fallback_active(model: str | None) -> bool:
|
||||
"""Return True if tool-result images for *model* should be routed
|
||||
through the vision-fallback chain rather than sent to the model.
|
||||
|
||||
Trigger: the model's catalog entry has ``supports_vision: false``
|
||||
(resolved via :func:`capabilities.supports_image_tool_results`,
|
||||
which reads ``model_catalog.json``). Unknown models default to
|
||||
vision-capable, so the fallback only fires when the catalog
|
||||
explicitly says the model is text-only.
|
||||
|
||||
The ``vision_fallback`` config block is the *substitution* model —
|
||||
it doesn't widen the trigger. To force fallback for a model that
|
||||
isn't catalogued yet, add an entry to ``model_catalog.json`` with
|
||||
``supports_vision: false`` rather than relying on a runtime config.
|
||||
"""
|
||||
if not model:
|
||||
return False
|
||||
return not supports_image_tool_results(model)
|
||||
|
||||
|
||||
async def _captioning_chain(
|
||||
intent: str,
|
||||
image_content: list[dict[str, Any]],
|
||||
) -> str | None:
|
||||
"""Two-stage caption chain used by the agent-loop tool-result hook.
|
||||
|
||||
Stage 1: configured ``vision_fallback`` model with intent + images.
|
||||
Stage 2: generic-caption rotation (gpt-4o-mini → claude-3-haiku
|
||||
→ gemini-flash) when stage 1 is unconfigured or fails.
|
||||
|
||||
Returns the caption text or None if both stages fail. Caller is
|
||||
responsible for the placeholder-on-None and the splice into the
|
||||
persisted tool-result content.
|
||||
"""
|
||||
caption = await caption_tool_image(intent, image_content)
|
||||
if not caption:
|
||||
caption = await _describe_images_as_text(image_content)
|
||||
return caption
|
||||
|
||||
|
||||
# Pattern for detecting context-window-exceeded errors across LLM providers.
|
||||
_CONTEXT_TOO_LARGE_RE = re.compile(
|
||||
r"context.{0,20}(length|window|limit|size)|"
|
||||
@@ -575,6 +620,7 @@ class AgentLoop(AgentProtocol):
|
||||
store=self._conversation_store,
|
||||
run_id=ctx.effective_run_id,
|
||||
compaction_buffer_tokens=self._config.compaction_buffer_tokens,
|
||||
compaction_buffer_ratio=self._config.compaction_buffer_ratio,
|
||||
compaction_warning_buffer_tokens=(self._config.compaction_warning_buffer_tokens),
|
||||
)
|
||||
accumulator = OutputAccumulator(
|
||||
@@ -587,7 +633,12 @@ class AgentLoop(AgentProtocol):
|
||||
|
||||
initial_message = self._build_initial_message(ctx)
|
||||
if initial_message:
|
||||
await conversation.add_user_message(initial_message)
|
||||
# Stamp with arrival time so the conversation has a
|
||||
# temporal anchor for the first turn, matching the
|
||||
# stamping done by drain_injection_queue for every
|
||||
# subsequent event.
|
||||
_stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
await conversation.add_user_message(f"[{_stamp}] {initial_message}")
|
||||
|
||||
await self._run_hooks("session_start", conversation, trigger=initial_message)
|
||||
|
||||
@@ -599,7 +650,8 @@ class AgentLoop(AgentProtocol):
|
||||
initial_message = self._build_initial_message(ctx)
|
||||
if not initial_message:
|
||||
initial_message = "Hello"
|
||||
await conversation.add_user_message(initial_message)
|
||||
_stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
await conversation.add_user_message(f"[{_stamp}] {initial_message}")
|
||||
|
||||
# 2b. Restore spill counter from existing files (resume safety)
|
||||
self._restore_spill_counter()
|
||||
@@ -619,8 +671,23 @@ class AgentLoop(AgentProtocol):
|
||||
# Hide image-producing tools from text-only models so they never try
|
||||
# to call them. Avoids wasted turns + "screenshot failed" lessons
|
||||
# getting saved to memory. See framework.llm.capabilities.
|
||||
# EXCEPTION: when the model IS on the text-only deny list AND
|
||||
# a vision_fallback subagent is configured, leave image tools
|
||||
# visible. The post-execution hook in the inner tool loop
|
||||
# will route each image_content through the fallback VLM and
|
||||
# replace it with a text caption before the main agent sees
|
||||
# the result — so the main agent gets captions instead of
|
||||
# raw images, rather than losing the tool entirely. We DON'T
|
||||
# bypass the filter for vision-capable models (that would be
|
||||
# a no-op anyway — the filter doesn't fire for them) and we
|
||||
# DON'T bypass it without a configured fallback (the agent
|
||||
# would just see raw stripped tool results with no caption).
|
||||
_llm_model = ctx.llm.model if ctx.llm else ""
|
||||
tools, _hidden_image_tools = filter_tools_for_model(tools, _llm_model)
|
||||
_text_only_main = _llm_model and not supports_image_tool_results(_llm_model)
|
||||
if _text_only_main and get_vision_fallback_model() is not None:
|
||||
_hidden_image_tools: list[str] = []
|
||||
else:
|
||||
tools, _hidden_image_tools = filter_tools_for_model(tools, _llm_model)
|
||||
|
||||
logger.info(
|
||||
"[%s] Tools available (%d): %s | direct_user_io=%s | judge=%s | hidden_image_tools=%s",
|
||||
@@ -793,14 +860,56 @@ class AgentLoop(AgentProtocol):
|
||||
tools.extend(synthetic)
|
||||
|
||||
# 6b3. Dynamic prompt refresh (phase switching / memory refresh)
|
||||
if ctx.dynamic_prompt_provider is not None or ctx.dynamic_memory_provider is not None:
|
||||
if (
|
||||
ctx.dynamic_prompt_provider is not None
|
||||
or ctx.dynamic_memory_provider is not None
|
||||
or ctx.dynamic_skills_catalog_provider is not None
|
||||
):
|
||||
if ctx.dynamic_prompt_provider is not None:
|
||||
_new_prompt = stamp_prompt_datetime(ctx.dynamic_prompt_provider())
|
||||
_new_prompt = ctx.dynamic_prompt_provider()
|
||||
# When a suffix provider is also wired (Queen's
|
||||
# static/dynamic split), keep the two pieces separate
|
||||
# so the LLM wrapper can emit them as two system
|
||||
# content blocks with a cache breakpoint between them.
|
||||
# The timestamp used to be stamped here via
|
||||
# stamp_prompt_datetime on every iteration — it now
|
||||
# lives inside the frozen dynamic suffix and is only
|
||||
# refreshed at user-turn boundaries, so per-iteration
|
||||
# stamping would both double-stamp and bust the cache.
|
||||
_new_suffix: str | None = None
|
||||
if ctx.dynamic_prompt_suffix_provider is not None:
|
||||
try:
|
||||
_new_suffix = ctx.dynamic_prompt_suffix_provider() or ""
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[%s] dynamic_prompt_suffix_provider raised — falling back to legacy stamp",
|
||||
node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
_new_suffix = None
|
||||
if _new_suffix is None:
|
||||
# Legacy / fallback path: no split in use (or the
|
||||
# suffix provider raised). Stamp the timestamp at
|
||||
# the end of the single-string prompt so the model
|
||||
# still sees a current "now".
|
||||
_new_prompt = stamp_prompt_datetime(_new_prompt)
|
||||
else:
|
||||
# build_system_prompt_for_context reads dynamic_skills_catalog_provider
|
||||
# directly; no separate branch needed.
|
||||
_new_prompt = build_system_prompt_for_context(ctx)
|
||||
if _new_prompt != conversation.system_prompt:
|
||||
conversation.update_system_prompt(_new_prompt)
|
||||
logger.info("[%s] Dynamic prompt updated", node_id)
|
||||
_new_suffix = None
|
||||
if _new_suffix is not None:
|
||||
_combined_for_compare = f"{_new_prompt}\n\n{_new_suffix}" if _new_suffix else _new_prompt
|
||||
if (
|
||||
_combined_for_compare != conversation.system_prompt
|
||||
or _new_suffix != conversation.system_prompt_dynamic_suffix
|
||||
):
|
||||
conversation.update_system_prompt(_new_prompt, dynamic_suffix=_new_suffix)
|
||||
logger.info("[%s] Dynamic prompt updated (split)", node_id)
|
||||
else:
|
||||
if _new_prompt != conversation.system_prompt:
|
||||
conversation.update_system_prompt(_new_prompt)
|
||||
logger.info("[%s] Dynamic prompt updated", node_id)
|
||||
|
||||
# 6c. Publish iteration event (with per-iteration metadata when available)
|
||||
_iter_meta = None
|
||||
@@ -890,6 +999,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens=turn_tokens.get("input", 0),
|
||||
output_tokens=turn_tokens.get("output", 0),
|
||||
cached_tokens=turn_tokens.get("cached", 0),
|
||||
cache_creation_tokens=turn_tokens.get("cache_creation", 0),
|
||||
cost_usd=float(turn_tokens.get("cost", 0.0) or 0.0),
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
@@ -2290,7 +2401,9 @@ class AgentLoop(AgentProtocol):
|
||||
stream_id = ctx.stream_id or ctx.agent_id
|
||||
node_id = ctx.agent_id
|
||||
execution_id = ctx.execution_id or ""
|
||||
token_counts: dict[str, int] = {"input": 0, "output": 0, "cached": 0}
|
||||
# Mixed-type dict: int token counts + str stop_reason/model + float cost.
|
||||
# Typed loosely to avoid churn in the many call sites that read from it.
|
||||
token_counts: dict[str, Any] = {"input": 0, "output": 0, "cached": 0, "cache_creation": 0, "cost": 0.0}
|
||||
tool_call_count = 0
|
||||
final_text = ""
|
||||
final_system_prompt = conversation.system_prompt
|
||||
@@ -2431,9 +2544,16 @@ class AgentLoop(AgentProtocol):
|
||||
nonlocal _first_event_at
|
||||
_clean_snapshot = "" # visible-only text for the frontend
|
||||
|
||||
# Split-prompt path: pass STATIC and DYNAMIC tail separately
|
||||
# so the LLM wrapper can emit them as two Anthropic system
|
||||
# content blocks with a cache breakpoint between them. When
|
||||
# no split is in use, ``system_prompt_static`` equals the
|
||||
# full prompt and the suffix is empty — identical to the
|
||||
# legacy single-block request.
|
||||
async for event in ctx.llm.stream(
|
||||
messages=_msgs,
|
||||
system=conversation.system_prompt,
|
||||
system=conversation.system_prompt_static,
|
||||
system_dynamic_suffix=(conversation.system_prompt_dynamic_suffix or None),
|
||||
tools=tools if tools else None,
|
||||
max_tokens=ctx.max_tokens,
|
||||
):
|
||||
@@ -2514,6 +2634,8 @@ class AgentLoop(AgentProtocol):
|
||||
token_counts["input"] += event.input_tokens
|
||||
token_counts["output"] += event.output_tokens
|
||||
token_counts["cached"] += event.cached_tokens
|
||||
token_counts["cache_creation"] += event.cache_creation_tokens
|
||||
token_counts["cost"] = token_counts.get("cost", 0.0) + event.cost_usd
|
||||
token_counts["stop_reason"] = event.stop_reason
|
||||
token_counts["model"] = event.model
|
||||
|
||||
@@ -3306,6 +3428,30 @@ class AgentLoop(AgentProtocol):
|
||||
|
||||
# Phase 3: record results into conversation in original order,
|
||||
# build logged/real lists, and publish completed events.
|
||||
#
|
||||
# Vision-fallback prefetch: a single turn may fire several
|
||||
# image-producing tools in parallel (e.g. one screenshot
|
||||
# per tab). Captioning each one takes a vision LLM round
|
||||
# trip (1–30 s). Doing them sequentially in this loop
|
||||
# would serialise that latency per image. Instead, kick
|
||||
# off all caption tasks concurrently NOW, and await each
|
||||
# one just-in-time inside the per-tc body. If only a
|
||||
# single image needs captioning, this collapses to a
|
||||
# single await with no overhead.
|
||||
_model_text_only = ctx.llm and _vision_fallback_active(ctx.llm.model)
|
||||
caption_tasks: dict[str, asyncio.Task[str | None]] = {}
|
||||
if _model_text_only:
|
||||
for tc in tool_calls[:executed_in_batch]:
|
||||
res = results_by_id.get(tc.tool_use_id)
|
||||
if not res or not res.image_content:
|
||||
continue
|
||||
intent = extract_intent_for_tool(
|
||||
conversation,
|
||||
tc.tool_name,
|
||||
tc.tool_input or {},
|
||||
)
|
||||
caption_tasks[tc.tool_use_id] = asyncio.create_task(_captioning_chain(intent, res.image_content))
|
||||
|
||||
for tc in tool_calls[:executed_in_batch]:
|
||||
result = results_by_id.get(tc.tool_use_id)
|
||||
if result is None:
|
||||
@@ -3328,11 +3474,30 @@ class AgentLoop(AgentProtocol):
|
||||
logged_tool_calls.append(tool_entry)
|
||||
|
||||
image_content = result.image_content
|
||||
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Stripping image_content from tool result; model '%s' does not support images in tool results",
|
||||
ctx.llm.model,
|
||||
)
|
||||
# Vision-fallback marker spliced into the persisted text
|
||||
# below. None when no captioning ran (vision-capable
|
||||
# main model, no images, or no fallback chain reached
|
||||
# this tool).
|
||||
vision_fallback_marker: str | None = None
|
||||
if image_content and tc.tool_use_id in caption_tasks:
|
||||
caption = await caption_tasks.pop(tc.tool_use_id)
|
||||
if caption:
|
||||
vision_fallback_marker = f"[vision-fallback caption]\n{caption}"
|
||||
logger.info(
|
||||
"vision_fallback: captioned %d image(s) for tool '%s' (model '%s' routed through fallback)",
|
||||
len(image_content),
|
||||
tc.tool_name,
|
||||
ctx.llm.model if ctx.llm else "?",
|
||||
)
|
||||
else:
|
||||
vision_fallback_marker = "[image stripped — vision fallback exhausted]"
|
||||
logger.info(
|
||||
"vision_fallback: exhausted; stripping %d image(s) from "
|
||||
"tool '%s' result without caption (model '%s')",
|
||||
len(image_content),
|
||||
tc.tool_name,
|
||||
ctx.llm.model if ctx.llm else "?",
|
||||
)
|
||||
image_content = None
|
||||
|
||||
# Apply replay-detector steer prefix if this call matched a
|
||||
@@ -3344,6 +3509,11 @@ class AgentLoop(AgentProtocol):
|
||||
if _prefix:
|
||||
stored_content = f"{_prefix}{stored_content or ''}"
|
||||
|
||||
# Splice the vision-fallback caption / placeholder into
|
||||
# the persisted text after any prefix has been applied.
|
||||
if vision_fallback_marker:
|
||||
stored_content = f"{stored_content or ''}\n\n{vision_fallback_marker}"
|
||||
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=stored_content,
|
||||
@@ -4095,6 +4265,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
@@ -4107,6 +4279,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
@@ -427,9 +427,20 @@ class NodeConversation:
|
||||
store: ConversationStore | None = None,
|
||||
run_id: str | None = None,
|
||||
compaction_buffer_tokens: int | None = None,
|
||||
compaction_buffer_ratio: float | None = None,
|
||||
compaction_warning_buffer_tokens: int | None = None,
|
||||
) -> None:
|
||||
self._system_prompt = system_prompt
|
||||
# Optional split: when a caller updates the prompt with a
|
||||
# ``dynamic_suffix`` argument, we remember the static prefix and
|
||||
# suffix separately so the LLM wrapper can emit them as two
|
||||
# Anthropic system content blocks with a cache breakpoint between
|
||||
# them. ``_system_prompt`` stays as the concatenated form used for
|
||||
# persistence and for the legacy single-block LLM path.
|
||||
# On restore, these default to the concat/empty pair — the next
|
||||
# AgentLoop iteration's dynamic-prompt refresh step repopulates.
|
||||
self._system_prompt_static: str = system_prompt
|
||||
self._system_prompt_dynamic_suffix: str = ""
|
||||
self._max_context_tokens = max_context_tokens
|
||||
self._compaction_threshold = compaction_threshold
|
||||
# Buffer-based compaction trigger (Gap 7). When set, takes
|
||||
@@ -439,6 +450,11 @@ class NodeConversation:
|
||||
# limit. If left as None the legacy threshold-based rule is
|
||||
# used, keeping old call sites behaving identically.
|
||||
self._compaction_buffer_tokens = compaction_buffer_tokens
|
||||
# Ratio component of the hybrid buffer. Combines additively with
|
||||
# _compaction_buffer_tokens so callers can express "reserve N tokens
|
||||
# plus M% of the window" — the absolute floor matters on tiny
|
||||
# windows, the ratio matters on large ones.
|
||||
self._compaction_buffer_ratio = compaction_buffer_ratio
|
||||
self._compaction_warning_buffer_tokens = compaction_warning_buffer_tokens
|
||||
self._output_keys = output_keys
|
||||
self._store = store
|
||||
@@ -453,15 +469,56 @@ class NodeConversation:
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
"""Full concatenated system prompt (static + dynamic suffix, if any).
|
||||
|
||||
This is the canonical form used for persistence and for the legacy
|
||||
single-block LLM path. Split-prompt callers should read
|
||||
``system_prompt_static`` and ``system_prompt_dynamic_suffix`` instead.
|
||||
"""
|
||||
return self._system_prompt
|
||||
|
||||
def update_system_prompt(self, new_prompt: str) -> None:
|
||||
@property
|
||||
def system_prompt_static(self) -> str:
|
||||
"""Static prefix of the system prompt (cache-stable).
|
||||
|
||||
Equals ``system_prompt`` when no split is in use. When the AgentLoop
|
||||
calls ``update_system_prompt(static, dynamic_suffix=...)``, this is
|
||||
the piece sent as the cache-controlled first block.
|
||||
"""
|
||||
return self._system_prompt_static
|
||||
|
||||
@property
|
||||
def system_prompt_dynamic_suffix(self) -> str:
|
||||
"""Dynamic tail of the system prompt (not cached).
|
||||
|
||||
Empty unless the consumer splits its prompt. The LLM wrapper uses a
|
||||
non-empty suffix to emit a two-block system content list with a
|
||||
cache breakpoint between the static prefix and this tail.
|
||||
"""
|
||||
return self._system_prompt_dynamic_suffix
|
||||
|
||||
def update_system_prompt(self, new_prompt: str, dynamic_suffix: str | None = None) -> None:
|
||||
"""Update the system prompt.
|
||||
|
||||
Used in continuous conversation mode at phase transitions to swap
|
||||
Layer 3 (focus) while preserving the conversation history.
|
||||
|
||||
When ``dynamic_suffix`` is provided, ``new_prompt`` is interpreted as
|
||||
the STATIC prefix and ``dynamic_suffix`` as the per-turn tail; they
|
||||
travel to the LLM as two separate cache-controlled blocks but are
|
||||
persisted as a single concatenated string for backward-compat
|
||||
restore. ``new_prompt`` alone (suffix left None) keeps the legacy
|
||||
single-string behavior.
|
||||
"""
|
||||
self._system_prompt = new_prompt
|
||||
if dynamic_suffix is None:
|
||||
# Legacy single-string path — static == full, no suffix split.
|
||||
self._system_prompt = new_prompt
|
||||
self._system_prompt_static = new_prompt
|
||||
self._system_prompt_dynamic_suffix = ""
|
||||
else:
|
||||
self._system_prompt_static = new_prompt
|
||||
self._system_prompt_dynamic_suffix = dynamic_suffix
|
||||
self._system_prompt = f"{new_prompt}\n\n{dynamic_suffix}" if dynamic_suffix else new_prompt
|
||||
self._meta_persisted = False # re-persist with new prompt
|
||||
|
||||
def set_current_phase(self, phase_id: str) -> None:
|
||||
@@ -847,19 +904,30 @@ class NodeConversation:
|
||||
"""True when the conversation should be compacted before the
|
||||
next LLM call.
|
||||
|
||||
Buffer-based rule (Gap 7): trigger when the current estimate
|
||||
plus the configured buffer would exceed the hard context limit.
|
||||
Prevents compaction from firing only AFTER we're already over
|
||||
the wire and forced into a reactive binary-split pass.
|
||||
Hybrid buffer rule: the headroom reserved before compaction fires
|
||||
is the SUM of an absolute fixed component and a ratio of the hard
|
||||
context limit:
|
||||
|
||||
When no buffer is configured, falls back to the multiplicative
|
||||
threshold the old callers were built around.
|
||||
effective_buffer = compaction_buffer_tokens
|
||||
+ compaction_buffer_ratio * max_context_tokens
|
||||
|
||||
The fixed component gives a floor on tiny windows; the ratio
|
||||
keeps the trigger meaningful on large windows where any constant
|
||||
buffer becomes a rounding error (an 8k buffer is 75% on a 32k
|
||||
window but 96% on a 200k window). Compaction fires when the
|
||||
current estimate would consume more than (limit - effective_buffer).
|
||||
|
||||
When neither component is configured, falls back to the legacy
|
||||
multiplicative threshold so old callers keep behaving identically.
|
||||
"""
|
||||
if self._max_context_tokens <= 0:
|
||||
return False
|
||||
if self._compaction_buffer_tokens is not None:
|
||||
budget = self._max_context_tokens - self._compaction_buffer_tokens
|
||||
return self.estimate_tokens() >= max(0, budget)
|
||||
fixed = self._compaction_buffer_tokens
|
||||
ratio = self._compaction_buffer_ratio
|
||||
if fixed is not None or ratio is not None:
|
||||
effective_buffer = (fixed or 0) + (ratio or 0.0) * self._max_context_tokens
|
||||
budget = self._max_context_tokens - effective_buffer
|
||||
return self.estimate_tokens() >= max(0.0, budget)
|
||||
return self.estimate_tokens() >= self._max_context_tokens * self._compaction_threshold
|
||||
|
||||
def compaction_warning(self) -> bool:
|
||||
@@ -1516,6 +1584,7 @@ class NodeConversation:
|
||||
"max_context_tokens": self._max_context_tokens,
|
||||
"compaction_threshold": self._compaction_threshold,
|
||||
"compaction_buffer_tokens": self._compaction_buffer_tokens,
|
||||
"compaction_buffer_ratio": self._compaction_buffer_ratio,
|
||||
"compaction_warning_buffer_tokens": (self._compaction_warning_buffer_tokens),
|
||||
"output_keys": self._output_keys,
|
||||
}
|
||||
@@ -1565,6 +1634,7 @@ class NodeConversation:
|
||||
store=store,
|
||||
run_id=run_id,
|
||||
compaction_buffer_tokens=meta.get("compaction_buffer_tokens"),
|
||||
compaction_buffer_ratio=meta.get("compaction_buffer_ratio"),
|
||||
compaction_warning_buffer_tokens=meta.get("compaction_warning_buffer_tokens"),
|
||||
)
|
||||
conv._meta_persisted = True
|
||||
|
||||
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from framework.agent_loop.conversation import ConversationStore, NodeConversation
|
||||
@@ -191,15 +192,21 @@ async def drain_injection_queue(
|
||||
else:
|
||||
logger.info("[drain] no vision fallback available; images dropped")
|
||||
image_content = None
|
||||
# Real user input is stored as-is; external events get a prefix
|
||||
# Stamp every injected event with its arrival time so the model
|
||||
# has a consistent temporal log to reason over (and so the
|
||||
# stamp lives inside byte-stable conversation history instead
|
||||
# of a per-turn system-prompt tail). Minute precision is what
|
||||
# the queen needs for conversational / scheduling context.
|
||||
stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
if is_client_input:
|
||||
stamped = f"[{stamp}] {content}" if content else f"[{stamp}]"
|
||||
await conversation.add_user_message(
|
||||
content,
|
||||
stamped,
|
||||
is_client_input=True,
|
||||
image_content=image_content,
|
||||
)
|
||||
else:
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
await conversation.add_user_message(f"[{stamp}] [External event] {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -232,7 +239,8 @@ async def drain_trigger_queue(
|
||||
payload_str = json.dumps(t.payload, default=str)
|
||||
parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
combined = f"[{stamp}]\n" + "\n\n".join(parts)
|
||||
logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200])
|
||||
# Tag the message so the UI can render a banner instead of the raw
|
||||
# `[TRIGGER: ...]` text. The LLM still sees `combined` verbatim.
|
||||
|
||||
@@ -108,6 +108,8 @@ async def publish_llm_turn_complete(
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
@@ -120,6 +122,8 @@ async def publish_llm_turn_complete(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
@@ -69,6 +69,20 @@ class LoopConfig:
|
||||
# and less tight than Anthropic's own counting. Override via
|
||||
# LoopConfig for larger windows.
|
||||
compaction_buffer_tokens: int = 8_000
|
||||
# Ratio-based component of the hybrid compaction buffer. Effective
|
||||
# headroom reserved before compaction fires is
|
||||
# compaction_buffer_tokens + compaction_buffer_ratio * max_context_tokens
|
||||
# The ratio scales with the model's window where the absolute fixed
|
||||
# component does not (an 8k absolute buffer is 75% trigger on a 32k
|
||||
# window but 96% on a 200k window). Combining them gives an absolute
|
||||
# floor sized for the worst-case single tool result (one un-spilled
|
||||
# max_tool_result_chars payload ≈ 30k chars ≈ 7.5k tokens, rounded to
|
||||
# 8k) plus a fractional headroom that keeps the trigger meaningful on
|
||||
# large windows, so the inner tool loop always has room to grow
|
||||
# without tripping the mid-turn pre-send guard. Defaults: 8k + 15%.
|
||||
# On 32k that's a 12.8k buffer (~60% trigger); on 200k it's 38k
|
||||
# (~81% trigger); on 1M it's 158k (~84% trigger).
|
||||
compaction_buffer_ratio: float = 0.15
|
||||
# Warning is emitted one buffer earlier so the user/telemetry gets
|
||||
# a "we're close" signal without triggering a compaction pass.
|
||||
compaction_warning_buffer_tokens: int = 12_000
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Vision-fallback subagent for tool-result images on text-only LLMs.
|
||||
|
||||
When a tool returns image content but the main agent's model can't
|
||||
accept image blocks (i.e. its catalog entry has ``supports_vision: false``),
|
||||
the framework strips the images before they ever reach the LLM. Without
|
||||
this module, the agent then sees only the tool's text envelope (URL,
|
||||
dimensions, size) and is blind to whatever the image actually shows.
|
||||
|
||||
This module provides:
|
||||
|
||||
* ``caption_tool_image()`` — direct LiteLLM call to a configured
|
||||
vision model (``vision_fallback`` block in ``~/.hive/configuration.json``)
|
||||
that takes the agent's intent + the image(s) and returns a textual
|
||||
description tailored to that intent.
|
||||
* ``extract_intent_for_tool()`` — pull the most recent assistant text
|
||||
+ the tool call descriptor and concatenate them into a ≤2KB intent
|
||||
string the vision subagent can reason against.
|
||||
|
||||
Both helpers degrade silently — return ``None`` / a placeholder rather
|
||||
than raise — so a vision-fallback failure can never kill the main
|
||||
agent's run. The agent-loop call site is responsible for chaining
|
||||
through to the existing generic-caption rotation
|
||||
(``_describe_images_as_text``) on a None return.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.config import (
|
||||
get_vision_fallback_api_base,
|
||||
get_vision_fallback_api_key,
|
||||
get_vision_fallback_model,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..conversation import NodeConversation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Hard cap on the intent string handed to the vision subagent. The
|
||||
# subagent only needs the agent's recent reasoning + the tool descriptor;
|
||||
# anything longer is wasted tokens (and risks pushing past the vision
|
||||
# model's context with the image attached).
|
||||
_INTENT_MAX_CHARS = 4096
|
||||
|
||||
# Cap on the tool args JSON snippet inside the intent. Some tool inputs
|
||||
# (large strings, file contents) would dominate the intent if uncapped.
|
||||
_TOOL_ARGS_MAX_CHARS = 4096
|
||||
|
||||
# Subagent system prompt — kept short so it fits within any provider's
|
||||
# system-prompt budget alongside the user message + image. Tells the
|
||||
# subagent its role and constrains output format.
|
||||
#
|
||||
# Coordinate labeling: the main agent's browser tools
|
||||
# (browser_click_coordinate / browser_hover_coordinate / browser_press_at)
|
||||
# accept VIEWPORT FRACTIONS (x, y) in [0..1] where (0,0) is the top-left
|
||||
# and (1,1) is the bottom-right of the screenshot. Without coordinates
|
||||
# the text-only agent has no way to act on what we describe — it can
|
||||
# read the caption but cannot point. So for every interactive element
|
||||
# we name (button, link, input, icon, tab, menu item, dialog control),
|
||||
# include its approximate viewport-fraction centre as ``(fx, fy)``
|
||||
# right after the element's name, e.g. ``"Submit" button (0.83, 0.92)``.
|
||||
# Three rules: (1) coordinates only for things plausibly clickable /
|
||||
# hoverable / typeable — don't tag pure body text or decorative
|
||||
# graphics. (2) Eyeball to two decimal places; precision beyond that
|
||||
# is false confidence. (3) Never invent — if an element is partly
|
||||
# off-screen or you can't locate it, omit the coordinate rather than
|
||||
# guessing.
|
||||
_VISION_SUBAGENT_SYSTEM = (
|
||||
"You are a vision subagent for a text-only main agent. The main "
|
||||
"agent invoked a tool that returned the image(s) attached. Their "
|
||||
"intent (their reasoning + the tool call) is below. Describe what "
|
||||
"the image shows in service of their intent — concrete, factual, "
|
||||
"no speculation. If their intent asks a yes/no question, answer it "
|
||||
"directly first.\n\n"
|
||||
"Coordinate labeling: the main agent uses fractional viewport "
|
||||
"coordinates (x, y) in [0..1] — (0, 0) is the top-left of the "
|
||||
"image, (1, 1) is the bottom-right — to drive its click / hover / "
|
||||
"key-press tools. For every interactive element you mention "
|
||||
"(button, link, input, checkbox, radio, dropdown, tab, menu item, "
|
||||
"dialog control, icon), append its approximate centre as "
|
||||
"``(fx, fy)`` immediately after the element's name or label, e.g. "
|
||||
'``"Submit" button (0.83, 0.92)`` or ``profile avatar icon '
|
||||
"(0.05, 0.07)``. Use two decimal places — more is false precision. "
|
||||
"Skip coordinates for pure body text and decorative elements that "
|
||||
"aren't clickable. If an element is partially off-screen or you "
|
||||
"cannot reliably locate its centre, omit the coordinate rather "
|
||||
"than guessing.\n\n"
|
||||
"Output plain text, no markdown, ≤ 600 words."
|
||||
)
|
||||
|
||||
|
||||
def extract_intent_for_tool(
    conversation: NodeConversation,
    tool_name: str,
    tool_args: dict[str, Any] | None,
) -> str:
    """Build the intent string handed to the vision subagent.

    Couples a structured tool-call descriptor with the most recent
    assistant text — the reasoning the LLM produced just before invoking
    the tool. The combined string is capped at ``_INTENT_MAX_CHARS``;
    the tool descriptor is kept intact and the reasoning is truncated
    from the tail, since goal-stating sentences usually sit at the head.

    When no prior assistant text exists (rare — first turn), the literal
    ``"<no preceding reasoning>"`` stands in so the subagent still
    receives the tool descriptor.
    """
    try:
        serialized_args = json.dumps(tool_args or {}, default=str)
    except Exception:
        serialized_args = repr(tool_args)
    if len(serialized_args) > _TOOL_ARGS_MAX_CHARS:
        serialized_args = serialized_args[:_TOOL_ARGS_MAX_CHARS] + "…"

    call_descriptor = f"Called: {tool_name}({serialized_args})"

    # Scan newest → oldest for the latest assistant message carrying text.
    reasoning = ""
    try:
        history = getattr(conversation, "_messages", []) or []
        for message in reversed(history):
            if getattr(message, "role", None) != "assistant":
                continue
            body = getattr(message, "content", "") or ""
            if isinstance(body, str) and body.strip():
                reasoning = body.strip()
                break
    except Exception:
        # Defensive — the agent loop must keep running even if the
        # conversation object changes shape underneath us.
        reasoning = ""

    # Intent = tool descriptor (always intact) + reasoning (truncated).
    header = f"{call_descriptor}\n\nReasoning before call:\n"
    remaining = _INTENT_MAX_CHARS - len(header)
    if remaining < 100:
        # Pathologically large tool descriptor — cap the header itself.
        return header[:_INTENT_MAX_CHARS]
    if not reasoning:
        reasoning = "<no preceding reasoning>"
    if len(reasoning) > remaining:
        reasoning = reasoning[: remaining - 1] + "…"
    return header + reasoning
|
||||
|
||||
|
||||
async def caption_tool_image(
    intent: str,
    image_content: list[dict[str, Any]],
    *,
    timeout_s: float = 30.0,
) -> str | None:
    """Caption the given images using the configured ``vision_fallback`` model.

    Args:
        intent: Intent string built by ``extract_intent_for_tool`` — sent
            as the leading text block of the user turn.
        image_content: OpenAI-style content blocks for the images, passed
            through to the model verbatim.
        timeout_s: Per-request timeout forwarded to litellm.

    Returns the model's text response on success, or ``None`` on any
    failure (no config, no API key, timeout, exception, empty
    response). Callers chain to the next stage of the fallback on None.

    Logs each call to ``~/.hive/llm_logs`` via ``log_llm_turn`` so the
    cost / latency / quality are auditable post-hoc, tagged with
    ``execution_id="vision_fallback_subagent"``.
    """
    model = get_vision_fallback_model()
    if not model:
        return None

    api_key = get_vision_fallback_api_key()
    api_base = get_vision_fallback_api_base()
    if not api_key:
        logger.debug("vision_fallback configured but no API key resolved; skipping")
        return None

    # litellm is an optional dependency here — treat its absence as just
    # another "this stage unavailable" signal rather than an error.
    try:
        import litellm
    except ImportError:
        return None

    user_blocks: list[dict[str, Any]] = [{"type": "text", "text": intent}]
    user_blocks.extend(image_content)
    messages = [
        {"role": "system", "content": _VISION_SUBAGENT_SYSTEM},
        {"role": "user", "content": user_blocks},
    ]

    kwargs: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "max_tokens": 1024,
        "timeout": timeout_s,
        "api_key": api_key,
    }
    if api_base:
        kwargs["api_base"] = api_base

    started = datetime.now()
    caption: str | None = None
    error_text: str | None = None
    try:
        response = await litellm.acompletion(**kwargs)
        # An empty/whitespace-only response leaves ``caption`` as None so
        # the caller falls through to the next fallback stage.
        text = (response.choices[0].message.content or "").strip()
        if text:
            caption = text
    except Exception as exc:
        error_text = f"{type(exc).__name__}: {exc}"
        logger.debug("vision_fallback model '%s' failed: %s", model, exc)

    # Best-effort audit log so users can grep ~/.hive/llm_logs/ for
    # vision-fallback subagent calls. Failures here must not bubble.
    try:
        from framework.tracker.llm_debug_logger import log_llm_turn

        # Don't dump the base64 image data into the log file — that
        # would balloon the jsonl with mostly-binary noise.
        elided_blocks: list[dict[str, Any]] = [{"type": "text", "text": intent}]
        elided_blocks.extend({"type": "image_url", "image_url": {"url": "<elided>"}} for _ in range(len(image_content)))
        log_llm_turn(
            node_id="vision_fallback_subagent",
            stream_id="vision_fallback",
            execution_id="vision_fallback_subagent",
            iteration=0,
            system_prompt=_VISION_SUBAGENT_SYSTEM,
            messages=[{"role": "user", "content": elided_blocks}],
            assistant_text=caption or "",
            tool_calls=[],
            tool_results=[],
            # Piggyback call metadata on token_counts so it lands in the
            # same jsonl record; ``error`` is None on success.
            token_counts={
                "model": model,
                "elapsed_s": (datetime.now() - started).total_seconds(),
                "error": error_text,
                "num_images": len(image_content),
                "intent_chars": len(intent),
            },
        )
    except Exception:
        # Audit logging is strictly best-effort.
        pass

    return caption
|
||||
|
||||
|
||||
# Public surface of this module; everything else stays internal.
__all__ = ["caption_tool_image", "extract_intent_for_tool"]
|
||||
@@ -53,7 +53,14 @@ def build_prompt_spec(
|
||||
# trigger tools are present in this agent's tool list (e.g. browser_*
|
||||
# pulls in hive.browser-automation). Keeps non-browser agents lean.
|
||||
tool_names = [getattr(t, "name", "") for t in (getattr(ctx, "available_tools", None) or [])]
|
||||
skills_catalog_prompt = augment_catalog_for_tools(ctx.skills_catalog_prompt or "", tool_names)
|
||||
raw_catalog = ctx.skills_catalog_prompt or ""
|
||||
dynamic_catalog = getattr(ctx, "dynamic_skills_catalog_provider", None)
|
||||
if dynamic_catalog is not None:
|
||||
try:
|
||||
raw_catalog = dynamic_catalog() or ""
|
||||
except Exception:
|
||||
raw_catalog = ctx.skills_catalog_prompt or ""
|
||||
skills_catalog_prompt = augment_catalog_for_tools(raw_catalog, tool_names)
|
||||
|
||||
return PromptSpec(
|
||||
identity_prompt=ctx.identity_prompt or "",
|
||||
|
||||
@@ -182,7 +182,24 @@ class AgentContext:
|
||||
|
||||
dynamic_tools_provider: Any = None
|
||||
dynamic_prompt_provider: Any = None
|
||||
# Optional Callable[[], str]: when set alongside ``dynamic_prompt_provider``,
|
||||
# the AgentLoop sends the system prompt as two pieces — the result of
|
||||
# ``dynamic_prompt_provider`` is the STATIC block (cached), and this
|
||||
# provider returns the DYNAMIC suffix (not cached). The LLM wrapper
|
||||
# emits them as two Anthropic system content blocks with a cache
|
||||
# breakpoint between them for providers that honor ``cache_control``.
|
||||
# For providers that don't, the two strings are concatenated. Used by
|
||||
# the Queen to keep her persona/role/tools block warm across iterations
|
||||
# while the recall + timestamp tail refreshes per user turn.
|
||||
dynamic_prompt_suffix_provider: Any = None
|
||||
dynamic_memory_provider: Any = None
|
||||
# Optional Callable[[], str]: when set, the current skills-catalog
|
||||
# prompt is sourced from this provider each iteration. Lets workers
|
||||
# pick up UI toggles without restarting the run. Queen agents already
|
||||
# rebuild the whole prompt via dynamic_prompt_provider — this field
|
||||
# is a surgical alternative used by colony workers where the rest of
|
||||
# the prompt stays constant and we don't want to thrash the cache.
|
||||
dynamic_skills_catalog_provider: Any = None
|
||||
|
||||
skills_catalog_prompt: str = ""
|
||||
protocols_prompt: str = ""
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
"""Per-queen tool configuration sidecar (``tools.json``).
|
||||
|
||||
Lives at ``~/.hive/agents/queens/{queen_id}/tools.json`` alongside
|
||||
``profile.yaml``. Kept separate so identity (name, title, core traits)
|
||||
stays human-authored and lean, while the machine-managed tool allowlist
|
||||
can grow (per-tool overrides, audit timestamps, future per-phase rules)
|
||||
without bloating the profile.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"enabled_mcp_tools": ["read_file", ...] | null,
|
||||
"updated_at": "2026-04-21T12:34:56+00:00"
|
||||
}
|
||||
|
||||
- ``null`` / missing file → default "allow every MCP tool".
|
||||
- ``[]`` → explicitly disable every MCP tool.
|
||||
- ``["foo", "bar"]`` → only those MCP tool names pass the filter.
|
||||
|
||||
Atomic writes via ``os.replace`` follow the same pattern as
|
||||
``framework.host.colony_metadata.update_colony_metadata``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from framework.config import QUEENS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tools_config_path(queen_id: str) -> Path:
    """Return the on-disk location of *queen_id*'s ``tools.json`` sidecar."""
    return QUEENS_DIR.joinpath(queen_id, "tools.json")
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
|
||||
"""Write ``data`` to ``path`` atomically via tempfile + replace."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(
|
||||
prefix=".tools.",
|
||||
suffix=".json.tmp",
|
||||
dir=str(path.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
json.dump(data, fh, indent=2)
|
||||
fh.flush()
|
||||
os.fsync(fh.fileno())
|
||||
os.replace(tmp, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _migrate_from_profile_if_needed(queen_id: str) -> list[str] | None:
    """Hoist a legacy ``enabled_mcp_tools`` field out of ``profile.yaml``.

    Returns the migrated value (or ``None`` if nothing to migrate). After
    migration the sidecar exists on disk and the profile YAML no longer
    contains ``enabled_mcp_tools``. Safe to call repeatedly.

    Note: a ``None`` return is ambiguous — it covers "nothing to
    migrate", "profile unreadable", AND "migrated an explicit null /
    malformed value". Callers that care (``load_queen_tools_config``)
    disambiguate by checking whether the sidecar now exists on disk.
    """
    profile_path = QUEENS_DIR / queen_id / "profile.yaml"
    if not profile_path.exists():
        return None
    try:
        data = yaml.safe_load(profile_path.read_text(encoding="utf-8"))
    except (yaml.YAMLError, OSError):
        # Unreadable profile: leave it alone and report nothing migrated.
        logger.warning("Could not read profile.yaml during tools migration: %s", queen_id)
        return None
    if not isinstance(data, dict):
        return None
    if "enabled_mcp_tools" not in data:
        return None

    # pop() mutates ``data`` so the rewrite below drops the legacy field.
    raw = data.pop("enabled_mcp_tools")
    enabled: list[str] | None
    if raw is None:
        enabled = None
    elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
        enabled = raw
    else:
        # Malformed legacy value: migrate as allow-all rather than crash.
        logger.warning(
            "Legacy enabled_mcp_tools on queen %s had unexpected shape %r; dropping",
            queen_id,
            raw,
        )
        enabled = None

    # Write sidecar first, then rewrite profile — if the second step
    # fails we still have the config available and won't re-migrate.
    _atomic_write_json(
        tools_config_path(queen_id),
        {
            "enabled_mcp_tools": enabled,
            "updated_at": datetime.now(UTC).isoformat(),
        },
    )
    # Non-atomic rewrite of profile.yaml; if it fails mid-write the
    # sidecar above is already authoritative for tool config.
    profile_path.write_text(
        yaml.safe_dump(data, sort_keys=False, allow_unicode=True),
        encoding="utf-8",
    )
    logger.info(
        "Migrated enabled_mcp_tools for queen %s from profile.yaml to tools.json",
        queen_id,
    )
    return enabled
|
||||
|
||||
|
||||
def tools_config_exists(queen_id: str) -> bool:
    """Return True when the queen has a persisted ``tools.json`` sidecar.

    Lets callers distinguish an explicit user save from a fallthrough to
    the role-based default — both can yield the same value from
    ``load_queen_tools_config``.
    """
    sidecar = tools_config_path(queen_id)
    return sidecar.exists()
|
||||
|
||||
|
||||
def delete_queen_tools_config(queen_id: str) -> bool:
    """Delete the queen's ``tools.json`` sidecar if present.

    Returns ``True`` if a file was removed, ``False`` if none existed or
    the delete failed. The next ``load_queen_tools_config`` call falls
    through to the role-based default (or allow-all for unknown queens).

    Uses EAFP (unlink and catch ``FileNotFoundError``) instead of an
    ``exists()`` pre-check — the check-then-delete form had a TOCTOU
    window where a concurrent delete raised an unhandled error path.
    """
    path = tools_config_path(queen_id)
    try:
        path.unlink()
    except FileNotFoundError:
        # No sidecar on disk — nothing to remove.
        return False
    except OSError:
        logger.warning("Failed to delete %s", path, exc_info=True)
        return False
    return True
|
||||
|
||||
|
||||
def load_queen_tools_config(
    queen_id: str,
    mcp_catalog: dict[str, list[dict]] | None = None,
) -> list[str] | None:
    """Return the queen's MCP tool allowlist, or ``None`` for default-allow.

    Order of resolution:
    1. ``tools.json`` sidecar (authoritative; user has saved).
    2. Legacy ``profile.yaml`` field (migrated and deleted on first read).
    3. Role-based default from ``queen_tools_defaults`` when the queen
       is in the known persona table. ``mcp_catalog`` lets the helper
       expand ``@server:NAME`` shorthands; without it, shorthand entries
       are dropped.
    4. ``None`` — default "allow every MCP tool".
    """
    path = tools_config_path(queen_id)
    if path.exists():
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            # Corrupt sidecar: fail open to default-allow rather than
            # lock the queen out of every MCP tool.
            logger.warning("Invalid %s; treating as default-allow", path)
            return None
        if not isinstance(data, dict):
            return None
        raw = data.get("enabled_mcp_tools")
        # Explicit null in the sidecar means "allow everything".
        if raw is None:
            return None
        if isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            return raw
        logger.warning("Unexpected enabled_mcp_tools shape in %s; ignoring", path)
        return None

    # No sidecar yet — try hoisting a legacy profile.yaml field.
    migrated = _migrate_from_profile_if_needed(queen_id)
    if migrated is not None:
        return migrated
    # If migration just hoisted an explicit ``null`` out of profile.yaml,
    # a sidecar with allow-all semantics now exists on disk. Honor that
    # over the role default so an explicit user choice wins.
    if tools_config_path(queen_id).exists():
        return None

    # No sidecar, nothing to migrate — fall back to role-based default.
    # Imported lazily to avoid a module-level cycle with the agents package.
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    return resolve_queen_default_tools(queen_id, mcp_catalog)
|
||||
|
||||
|
||||
def update_queen_tools_config(
    queen_id: str,
    enabled_mcp_tools: list[str] | None,
) -> list[str] | None:
    """Persist the queen's MCP allowlist to ``tools.json``.

    Raises ``FileNotFoundError`` when the queen's directory is missing —
    we refuse to silently fabricate a sidecar for a queen that doesn't
    exist. Returns the allowlist exactly as given.
    """
    if not (QUEENS_DIR / queen_id).exists():
        raise FileNotFoundError(f"Queen directory not found: {queen_id}")
    payload = {
        "enabled_mcp_tools": enabled_mcp_tools,
        "updated_at": datetime.now(UTC).isoformat(),
    }
    _atomic_write_json(tools_config_path(queen_id), payload)
    return enabled_mcp_tools
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Role-based default tool allowlists for queens.
|
||||
|
||||
Every queen inherits the same MCP surface (all servers loaded for the
|
||||
queen agent), but exposing 94+ tools to every persona clutters the LLM
|
||||
tool catalog and wastes prompt tokens. This module defines a sensible
|
||||
default allowlist per queen persona so, e.g., Head of Legal doesn't
|
||||
see port scanners and Head of Finance doesn't see ``apply_patch``.
|
||||
|
||||
Defaults apply only when the queen has no ``tools.json`` sidecar — the
|
||||
moment the user saves an allowlist through the Tool Library, the
|
||||
sidecar becomes authoritative. A DELETE on the tools endpoint removes
|
||||
the sidecar and brings the queen back to her role default.
|
||||
|
||||
Category entries support a ``@server:NAME`` shorthand that expands to
|
||||
every tool name registered against that MCP server in the current
|
||||
catalog. This keeps the category table short and drift-free when new
|
||||
tools are added (e.g. browser_* auto-joins the ``browser`` category).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Categories — reusable bundles of MCP tool names.
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each category is a flat list of either concrete tool names or the
|
||||
# ``@server:NAME`` shorthand. The shorthand expands to every tool the
|
||||
# given MCP server currently exposes (requires a live catalog; when one
|
||||
# is not available the shorthand is silently dropped so we fall back to
|
||||
# the named entries only).
|
||||
|
||||
_TOOL_CATEGORIES: dict[str, list[str]] = {
    # Read-only file operations — safe baseline for every knowledge queen.
    "file_read": [
        "read_file",
        "list_directory",
        "list_dir",
        "list_files",
        "search_files",
        "grep_search",
        "pdf_read",
    ],
    # File mutation — only personas that author or edit artifacts.
    "file_write": [
        "write_file",
        "edit_file",
        "apply_diff",
        "apply_patch",
        "replace_file_content",
        "hashline_edit",
        "undo_changes",
    ],
    # Shell + process control — engineering personas only.
    "shell": [
        "run_command",
        "execute_command_tool",
        "bash_kill",
        "bash_output",
    ],
    # Tabular data. CSV/Excel read/write + DuckDB SQL.
    "data": [
        "csv_read",
        "csv_info",
        "csv_write",
        "csv_append",
        "csv_sql",
        "excel_read",
        "excel_info",
        "excel_write",
        "excel_append",
        "excel_search",
        "excel_sheet_list",
        "excel_sql",
    ],
    # Browser automation — every tool from the gcu-tools MCP server.
    # NOTE: currently the only category using the ``@server:`` shorthand;
    # it expands at resolve time so new browser_* tools auto-join.
    "browser": ["@server:gcu-tools"],
    # External research / information-gathering.
    "research": [
        "search_papers",
        "download_paper",
        "search_wikipedia",
        "web_scrape",
    ],
    # Security scanners — pentest-ish, only for engineering/security roles.
    "security": [
        "dns_security_scan",
        "http_headers_scan",
        "port_scan",
        "ssl_tls_scan",
        "subdomain_enumerate",
        "tech_stack_detect",
        "risk_score",
    ],
    # Lightweight context helpers — good default for every queen.
    "time_context": [
        "get_current_time",
        "get_account_info",
    ],
    # Runtime log inspection — debug/observability for builder personas.
    "runtime_inspection": [
        "query_runtime_logs",
        "query_runtime_log_details",
        "query_runtime_log_raw",
    ],
    # Agent-management tools — building/validating/checking agents.
    "agent_mgmt": [
        "list_agents",
        "list_agent_tools",
        "list_agent_sessions",
        "get_agent_checkpoint",
        "list_agent_checkpoints",
        "run_agent_tests",
        "save_agent_draft",
        "confirm_and_build",
        "validate_agent_package",
        "validate_agent_tools",
        "enqueue_task",
    ],
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-queen mapping.
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Built from the queen personas in ``queen_profiles.DEFAULT_QUEENS``. The
|
||||
# goal is "just enough" — a queen should see tools she'd plausibly call
|
||||
# for her stated role, nothing more. Users curate further via the Tool
|
||||
# Library if they want.
|
||||
#
|
||||
# A queen whose ID is NOT in this map falls through to "allow every MCP
|
||||
# tool" (the original behavior), which keeps the system compatible with
|
||||
# user-added custom queen IDs that we don't know about.
|
||||
|
||||
# Maps queen profile ID → list of ``_TOOL_CATEGORIES`` keys. Resolved by
# ``resolve_queen_default_tools``; an unknown key here silently expands to
# nothing, so keep entries in sync with the category table above.
QUEEN_DEFAULT_CATEGORIES: dict[str, list[str]] = {
    # Head of Technology — builds and operates systems; full toolkit.
    "queen_technology": [
        "file_read",
        "file_write",
        "shell",
        "data",
        "browser",
        "research",
        "security",
        "time_context",
        "runtime_inspection",
        "agent_mgmt",
    ],
    # Head of Growth — data, experiments, competitor research; no shell/security.
    "queen_growth": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Product Strategy — user research + roadmaps; no shell/security.
    "queen_product_strategy": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Finance — financial models (CSV/Excel heavy), market research.
    "queen_finance_fundraising": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Legal — reads contracts/PDFs, researches; no shell/data/security.
    "queen_legal": [
        "file_read",
        "file_write",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Brand & Design — visual refs, style guides; no shell/data/security.
    "queen_brand_design": [
        "file_read",
        "file_write",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Talent — candidate pipelines, resumes; data + browser heavy.
    "queen_talent": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Operations — processes, automation, observability.
    "queen_operations": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
        "runtime_inspection",
        "agent_mgmt",
    ],
}
|
||||
|
||||
|
||||
def has_role_default(queen_id: str) -> bool:
    """Return True when ``queen_id`` has an entry in the category table."""
    return QUEEN_DEFAULT_CATEGORIES.get(queen_id) is not None
|
||||
|
||||
|
||||
def resolve_queen_default_tools(
    queen_id: str,
    mcp_catalog: dict[str, list[dict[str, Any]]] | None = None,
) -> list[str] | None:
    """Return the role-based default allowlist for ``queen_id``.

    Arguments:
        queen_id: Profile ID (e.g. ``"queen_technology"``).
        mcp_catalog: Optional mapping of ``{server_name: [{"name": ...}, ...]}``
            used to expand ``@server:NAME`` shorthands in categories.
            When absent, shorthand entries are dropped and the result
            contains only the explicitly-named tools.

    Returns:
        A deduplicated, first-seen-ordered list of tool names, or
        ``None`` if the queen has no role entry (caller should treat as
        "allow every MCP tool").
    """
    category_names = QUEEN_DEFAULT_CATEGORIES.get(queen_id)
    if not category_names:
        return None

    # A dict keyed by tool name dedupes while preserving insertion order.
    resolved: dict[str, None] = {}

    for category in category_names:
        for entry in _TOOL_CATEGORIES.get(category, []):
            if not entry.startswith("@server:"):
                # Concrete tool name — take it directly (skip empties).
                if entry:
                    resolved.setdefault(entry, None)
                continue
            # ``@server:NAME`` shorthand: expand against the live catalog.
            server = entry.removeprefix("@server:")
            if mcp_catalog is None:
                logger.debug(
                    "resolve_queen_default_tools: catalog missing; cannot expand %s",
                    entry,
                )
                continue
            for tool in mcp_catalog.get(server, []) or []:
                name = tool.get("name") if isinstance(tool, dict) else None
                if name:
                    resolved.setdefault(name, None)

    return list(resolved)
|
||||
@@ -155,6 +155,57 @@ def get_preferred_worker_model() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_vision_fallback_model() -> str | None:
    """Return the configured vision-fallback model, or None if not configured.

    Reads the ``vision_fallback`` section of ~/.hive/configuration.json
    and joins provider + model into a litellm-style identifier. A
    redundant ``openrouter/`` prefix on the model is stripped for the
    openrouter provider so the result never doubles up.

    When this returns None the fallback chain skips the
    configured-subagent stage and proceeds straight to the generic
    caption rotation (``_describe_images_as_text``).
    """
    section = get_hive_config().get("vision_fallback", {})
    provider_raw = section.get("provider")
    model_raw = section.get("model")
    if not (provider_raw and model_raw):
        return None
    provider = str(provider_raw)
    model = str(model_raw).strip()
    if provider.lower() == "openrouter" and model.lower().startswith("openrouter/"):
        model = model[len("openrouter/") :]
    return f"{provider}/{model}" if model else None
|
||||
|
||||
|
||||
def get_vision_fallback_api_key() -> str | None:
    """Return the API key for the vision-fallback model.

    Resolution:
    - When ``vision_fallback.api_key_env_var`` is configured, the key is
      read from that environment variable and returned as-is — which is
      ``None`` when the variable is unset. NOTE(review): this contradicts
      a plain reading of "env var, then the default get_api_key()" —
      presumably deliberate so a misconfigured env var never silently
      sends the default key to a different provider; confirm intent.
    - With no ``vision_fallback`` section, or no ``api_key_env_var``
      field, the default ``get_api_key()`` is used.

    No subscription-token branches — vision fallback is intended for
    hosted vision models (Anthropic, OpenAI, Google), not for the
    subscription-bearer providers.
    """
    vision = get_hive_config().get("vision_fallback", {})
    if not vision:
        return get_api_key()
    api_key_env_var = vision.get("api_key_env_var")
    if api_key_env_var:
        # May be None when the named variable is unset — see docstring.
        return os.environ.get(api_key_env_var)
    return get_api_key()
|
||||
|
||||
|
||||
def get_vision_fallback_api_base() -> str | None:
    """Return the api_base for the vision-fallback model, or None.

    An explicit ``api_base`` in the config wins; otherwise the openrouter
    provider implies the shared ``OPENROUTER_API_BASE`` endpoint.
    """
    section = get_hive_config().get("vision_fallback", {})
    if not section:
        return None
    explicit = section.get("api_base")
    if explicit:
        return explicit
    provider = str(section.get("provider", "")).lower()
    return OPENROUTER_API_BASE if provider == "openrouter" else None
|
||||
|
||||
|
||||
def get_worker_api_key() -> str | None:
|
||||
"""Return the API key for the worker LLM, falling back to the default key."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Read/write helpers for per-colony metadata.json.
|
||||
|
||||
A colony's metadata.json lives at ``{COLONIES_DIR}/{colony_name}/metadata.json``
|
||||
and holds immutable provenance: the queen that created it, the forked
|
||||
session id, creation/update timestamps, and the list of workers.
|
||||
|
||||
Mutable user-editable tool configuration lives in a sibling
|
||||
``tools.json`` sidecar — see :mod:`framework.host.colony_tools_config`
|
||||
— so identity and tool gating evolve independently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.config import COLONIES_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def colony_metadata_path(colony_name: str) -> Path:
    """Return the on-disk path to a colony's metadata.json."""
    return COLONIES_DIR.joinpath(colony_name, "metadata.json")
|
||||
|
||||
|
||||
def load_colony_metadata(colony_name: str) -> dict[str, Any]:
    """Load metadata.json for ``colony_name``.

    Returns an empty dict if the file is missing or malformed — callers
    are expected to treat missing fields as defaults, so there is no
    error to propagate.
    """
    path = colony_metadata_path(colony_name)
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        # Absent file is the normal "no metadata yet" case — no warning.
        return {}
    except OSError:
        logger.warning("Failed to read colony metadata at %s", path)
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("Failed to read colony metadata at %s", path)
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def update_colony_metadata(colony_name: str, updates: dict[str, Any]) -> dict[str, Any]:
    """Shallow-merge ``updates`` into metadata.json and persist.

    Returns the full merged dict. Raises ``FileNotFoundError`` if the
    colony does not exist. The write goes through a temp file plus
    ``os.replace`` so a concurrent reader never observes a half-written
    metadata.json.
    """
    import os
    import tempfile

    path = colony_metadata_path(colony_name)
    if not path.parent.exists():
        raise FileNotFoundError(f"Colony '{colony_name}' not found")

    merged = load_colony_metadata(colony_name) if path.exists() else {}
    merged.update(updates)

    path.parent.mkdir(parents=True, exist_ok=True)
    fd, scratch = tempfile.mkstemp(prefix=".metadata.", suffix=".json.tmp", dir=str(path.parent))
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as out:
            json.dump(merged, out, indent=2)
            out.flush()
            # fsync before rename so a crash can't surface a truncated file.
            os.fsync(out.fileno())
        os.replace(scratch, path)
    except BaseException:
        try:
            os.unlink(scratch)
        except OSError:
            pass
        raise
    return merged
|
||||
|
||||
|
||||
def list_colony_names() -> list[str]:
    """Return the names of every colony that has a metadata.json on disk."""
    if not COLONIES_DIR.is_dir():
        return []
    return [
        entry.name
        for entry in sorted(COLONIES_DIR.iterdir())
        if entry.is_dir() and (entry / "metadata.json").exists()
    ]
|
||||
return names
|
||||
@@ -185,6 +185,8 @@ class ColonyRuntime:
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
pipeline_stages: list | None = None,
|
||||
queen_id: str | None = None,
|
||||
colony_name: str | None = None,
|
||||
):
|
||||
from framework.pipeline.runner import PipelineRunner
|
||||
from framework.skills.manager import SkillsManager
|
||||
@@ -193,14 +195,27 @@ class ColonyRuntime:
|
||||
self._goal = goal
|
||||
self._config = config or ColonyConfig()
|
||||
self._runtime_log_store = runtime_log_store
|
||||
self._queen_id: str | None = queen_id
|
||||
# ``colony_id`` is the event-bus scope (session.id in DM sessions);
|
||||
# ``colony_name`` is the on-disk identity under ~/.hive/colonies/.
|
||||
# They coincide for forked colonies but diverge for queen DM
|
||||
# sessions, so separate them explicitly.
|
||||
self._colony_name: str | None = colony_name
|
||||
|
||||
if pipeline_stages:
|
||||
self._pipeline = PipelineRunner(pipeline_stages)
|
||||
else:
|
||||
self._pipeline = self._load_pipeline_from_config()
|
||||
|
||||
if skills_manager_config is not None:
|
||||
self._skills_manager = SkillsManager(skills_manager_config)
|
||||
# Resolve per-colony override paths so UI toggles can reach this
|
||||
# runtime. Callers that build their own SkillsManagerConfig stay
|
||||
# in charge; bare construction auto-wires the standard paths.
|
||||
_effective_cfg = skills_manager_config
|
||||
if _effective_cfg is None and not (skills_catalog_prompt or protocols_prompt):
|
||||
_effective_cfg = self._build_default_skills_config(colony_name, queen_id)
|
||||
|
||||
if _effective_cfg is not None:
|
||||
self._skills_manager = SkillsManager(_effective_cfg)
|
||||
self._skills_manager.load()
|
||||
elif skills_catalog_prompt or protocols_prompt:
|
||||
import warnings
|
||||
@@ -242,6 +257,19 @@ class ColonyRuntime:
|
||||
self._tools = tools or []
|
||||
self._tool_executor = tool_executor
|
||||
|
||||
# Per-colony MCP tool allowlist — applied when spawning workers. A
|
||||
# value of ``None`` means "allow every MCP tool" (default), an empty
|
||||
# list disables every MCP tool, and a list of names only enables
|
||||
# those. Lifecycle / synthetic tools always pass through the filter
|
||||
# because their names are absent from ``_mcp_tool_names_all``. The
|
||||
# allowlist is re-read on every ``spawn`` so a PATCH that mutates
|
||||
# this attribute via ``set_tool_allowlist`` takes effect on the
|
||||
# NEXT worker spawn without a runtime restart. In-flight workers
|
||||
# keep the tool list they booted with — workers have no dynamic
|
||||
# tools provider today.
|
||||
self._enabled_mcp_tools: list[str] | None = None
|
||||
self._mcp_tool_names_all: set[str] = set()
|
||||
|
||||
# Worker management
|
||||
self._workers: dict[str, Worker] = {}
|
||||
# The persistent client-facing overseer (optional). Set by
|
||||
@@ -384,6 +412,128 @@ class ColonyRuntime:
|
||||
return PipelineRunner([])
|
||||
return build_pipeline_from_config(stages_config)
|
||||
|
||||
@staticmethod
|
||||
def _build_default_skills_config(
|
||||
colony_name: str | None,
|
||||
queen_id: str | None,
|
||||
) -> SkillsManagerConfig:
|
||||
"""Assemble a ``SkillsManagerConfig`` that wires in the per-colony /
|
||||
per-queen override files and the ``queen_ui`` / ``colony_ui`` scope
|
||||
dirs based on the standard ``~/.hive`` layout.
|
||||
|
||||
``colony_name`` must be an actual on-disk colony name
|
||||
(``~/.hive/colonies/{name}/``). DM sessions where the ``colony_id``
|
||||
is a session UUID should pass ``None`` so we don't create a stray
|
||||
override file under a session identifier.
|
||||
"""
|
||||
from framework.config import COLONIES_DIR, QUEENS_DIR
|
||||
from framework.skills.discovery import ExtraScope
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
extras: list[ExtraScope] = []
|
||||
queen_overrides_path: Path | None = None
|
||||
if queen_id:
|
||||
queen_home = QUEENS_DIR / queen_id
|
||||
queen_overrides_path = queen_home / "skills_overrides.json"
|
||||
extras.append(ExtraScope(directory=queen_home / "skills", label="queen_ui", priority=2))
|
||||
|
||||
colony_overrides_path: Path | None = None
|
||||
if colony_name:
|
||||
colony_home = COLONIES_DIR / colony_name
|
||||
colony_overrides_path = colony_home / "skills_overrides.json"
|
||||
# Colony-scope SKILL.md dir is the project-scope from discovery's
|
||||
# point of view (colony_dir is the project_root). Add it also as
|
||||
# a tagged ``colony_ui`` scope so UI-created entries resolve with
|
||||
# correct provenance.
|
||||
extras.append(
|
||||
ExtraScope(
|
||||
directory=colony_home / ".hive" / "skills",
|
||||
label="colony_ui",
|
||||
priority=3,
|
||||
)
|
||||
)
|
||||
|
||||
return SkillsManagerConfig(
|
||||
queen_id=queen_id,
|
||||
queen_overrides_path=queen_overrides_path,
|
||||
colony_name=colony_name,
|
||||
colony_overrides_path=colony_overrides_path,
|
||||
extra_scope_dirs=extras,
|
||||
interactive=False, # HTTP-driven runtimes never prompt for consent
|
||||
)
|
||||
|
||||
@property
|
||||
def queen_id(self) -> str | None:
|
||||
"""The queen that owns this runtime, if known."""
|
||||
return self._queen_id
|
||||
|
||||
@property
|
||||
def colony_name(self) -> str | None:
|
||||
"""The on-disk colony name (distinct from event-bus scope ``colony_id``)."""
|
||||
return self._colony_name
|
||||
|
||||
@property
|
||||
def skills_manager(self):
|
||||
"""Access the live :class:`SkillsManager` (for HTTP handlers)."""
|
||||
return self._skills_manager
|
||||
|
||||
async def reload_skills(self) -> dict[str, Any]:
|
||||
"""Rebuild the catalog after an override change; in-flight workers
|
||||
pick up the new catalog on their next iteration via
|
||||
``dynamic_skills_catalog_provider``.
|
||||
|
||||
Returns a small stats dict that HTTP handlers can echo back to
|
||||
the UI ("applied — N skills now in catalog").
|
||||
"""
|
||||
async with self._skills_manager.mutation_lock:
|
||||
self._skills_manager.reload()
|
||||
self.skill_dirs = self._skills_manager.allowlisted_dirs
|
||||
self.batch_init_nudge = self._skills_manager.batch_init_nudge
|
||||
self.context_warn_ratio = self._skills_manager.context_warn_ratio
|
||||
catalog_prompt = self._skills_manager.skills_catalog_prompt
|
||||
return {
|
||||
"catalog_chars": len(catalog_prompt),
|
||||
"skill_dirs": list(self.skill_dirs),
|
||||
}
|
||||
|
||||
# ── Per-colony tool allowlist ───────────────────────────────
|
||||
|
||||
def set_tool_allowlist(
|
||||
self,
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
mcp_tool_names_all: set[str] | None = None,
|
||||
) -> None:
|
||||
"""Configure the per-colony MCP tool allowlist.
|
||||
|
||||
Called at construction time (from SessionManager) and again from
|
||||
the ``/api/colony/{name}/tools`` PATCH handler when a user edits
|
||||
the allowlist. The change applies to the NEXT worker spawn — we
|
||||
never mutate the tool list of a worker that is already running
|
||||
(workers have no dynamic tools provider, so hot-reloading their
|
||||
tool set would diverge from the list the LLM was already using).
|
||||
"""
|
||||
self._enabled_mcp_tools = list(enabled_mcp_tools) if enabled_mcp_tools is not None else None
|
||||
if mcp_tool_names_all is not None:
|
||||
self._mcp_tool_names_all = set(mcp_tool_names_all)
|
||||
|
||||
def _apply_tool_allowlist(self, tools: list) -> list:
|
||||
"""Filter ``tools`` against the colony's MCP allowlist.
|
||||
|
||||
Lifecycle / synthetic tools (those whose names are NOT in
|
||||
``_mcp_tool_names_all``) are never gated. MCP tools are kept only
|
||||
when ``_enabled_mcp_tools`` is None (default allow) or contains
|
||||
their name. Input list order is preserved so downstream cache
|
||||
keys and logs stay stable.
|
||||
"""
|
||||
if self._enabled_mcp_tools is None:
|
||||
return tools
|
||||
allowed = set(self._enabled_mcp_tools)
|
||||
return [
|
||||
t
|
||||
for t in tools
|
||||
if getattr(t, "name", None) not in self._mcp_tool_names_all or getattr(t, "name", None) in allowed
|
||||
]
|
||||
|
||||
# ── Lifecycle ───────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -658,6 +808,14 @@ class ColonyRuntime:
|
||||
spawn_tools = tools if tools is not None else self._tools
|
||||
spawn_executor = tool_executor or self._tool_executor
|
||||
|
||||
# Apply the per-colony MCP tool allowlist (if any). Done HERE —
|
||||
# after spawn_tools is resolved but before it's frozen into the
|
||||
# worker's AgentContext — so the next spawn reflects any PATCH
|
||||
# that happened since the last spawn. A value of ``None`` on
|
||||
# ``_enabled_mcp_tools`` is a no-op so the default path is
|
||||
# unchanged.
|
||||
spawn_tools = self._apply_tool_allowlist(spawn_tools)
|
||||
|
||||
# Colony progress tracker: when the caller supplied a db_path
|
||||
# in input_data, this worker is part of a SQLite task queue
|
||||
# and must see the hive.colony-progress-tracker skill body in
|
||||
@@ -740,6 +898,17 @@ class ColonyRuntime:
|
||||
conversation_store=worker_conv_store,
|
||||
)
|
||||
|
||||
# Workers pick up UI-driven override changes via this provider,
|
||||
# which reads the live catalog on each iteration. The db_path
|
||||
# pre-activated catalog stays static because its contents are
|
||||
# built for *this* worker's task (a tombstone toggle from the
|
||||
# UI should not yank it mid-run).
|
||||
_db_path_pre_activated = bool(isinstance(input_data, dict) and input_data.get("db_path"))
|
||||
# Default-bind the manager into the closure so each loop iteration
|
||||
# captures the same manager instance — pyflakes B023 would flag a
|
||||
# free-variable capture here.
|
||||
_provider = None if _db_path_pre_activated else (lambda mgr=self._skills_manager: mgr.skills_catalog_prompt)
|
||||
|
||||
agent_context = AgentContext(
|
||||
runtime=self._make_runtime_adapter(worker_id),
|
||||
agent_id=worker_id,
|
||||
@@ -753,6 +922,7 @@ class ColonyRuntime:
|
||||
skills_catalog_prompt=_spawn_catalog,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=_spawn_skill_dirs,
|
||||
dynamic_skills_catalog_provider=_provider,
|
||||
execution_id=worker_id,
|
||||
stream_id=explicit_stream_id or f"worker:{worker_id}",
|
||||
)
|
||||
@@ -997,6 +1167,7 @@ class ColonyRuntime:
|
||||
conversation_store=overseer_conv_store,
|
||||
)
|
||||
|
||||
_overseer_skills_mgr = self._skills_manager
|
||||
overseer_ctx = AgentContext(
|
||||
runtime=self._make_runtime_adapter(overseer_id),
|
||||
agent_id=overseer_id,
|
||||
@@ -1010,6 +1181,7 @@ class ColonyRuntime:
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
dynamic_skills_catalog_provider=lambda: _overseer_skills_mgr.skills_catalog_prompt,
|
||||
execution_id=overseer_id,
|
||||
stream_id="overseer",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Per-colony tool configuration sidecar (``tools.json``).
|
||||
|
||||
Lives at ``~/.hive/colonies/{colony_name}/tools.json`` alongside
|
||||
``metadata.json``. Kept separate so provenance (queen_name,
|
||||
created_at, workers) stays in metadata while the user-editable tool
|
||||
allowlist gets its own file.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"enabled_mcp_tools": ["read_file", ...] | null,
|
||||
"updated_at": "2026-04-21T12:34:56+00:00"
|
||||
}
|
||||
|
||||
- ``null`` / missing file → default "allow every MCP tool".
|
||||
- ``[]`` → explicitly disable every MCP tool.
|
||||
- ``["foo", "bar"]`` → only those MCP tool names pass the filter.
|
||||
|
||||
Atomic writes via ``os.replace`` mirror
|
||||
``framework.host.colony_metadata.update_colony_metadata``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.config import COLONIES_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tools_config_path(colony_name: str) -> Path:
|
||||
"""Return the on-disk path to a colony's ``tools.json``."""
|
||||
return COLONIES_DIR / colony_name / "tools.json"
|
||||
|
||||
|
||||
def _metadata_path(colony_name: str) -> Path:
|
||||
return COLONIES_DIR / colony_name / "metadata.json"
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(
|
||||
prefix=".tools.",
|
||||
suffix=".json.tmp",
|
||||
dir=str(path.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
json.dump(data, fh, indent=2)
|
||||
fh.flush()
|
||||
os.fsync(fh.fileno())
|
||||
os.replace(tmp, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _migrate_from_metadata_if_needed(colony_name: str) -> list[str] | None:
|
||||
"""Hoist a legacy ``enabled_mcp_tools`` field out of ``metadata.json``.
|
||||
|
||||
Returns the migrated value (or ``None`` if nothing to migrate). After
|
||||
migration the sidecar exists and ``metadata.json`` no longer contains
|
||||
``enabled_mcp_tools``. Safe to call repeatedly.
|
||||
"""
|
||||
meta_path = _metadata_path(colony_name)
|
||||
if not meta_path.exists():
|
||||
return None
|
||||
try:
|
||||
data = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Could not read metadata.json during tools migration: %s", colony_name)
|
||||
return None
|
||||
if not isinstance(data, dict) or "enabled_mcp_tools" not in data:
|
||||
return None
|
||||
|
||||
raw = data.pop("enabled_mcp_tools")
|
||||
enabled: list[str] | None
|
||||
if raw is None:
|
||||
enabled = None
|
||||
elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
|
||||
enabled = raw
|
||||
else:
|
||||
logger.warning(
|
||||
"Legacy enabled_mcp_tools on colony %s had unexpected shape %r; dropping",
|
||||
colony_name,
|
||||
raw,
|
||||
)
|
||||
enabled = None
|
||||
|
||||
# Sidecar first so a partial failure leaves the config recoverable.
|
||||
_atomic_write_json(
|
||||
tools_config_path(colony_name),
|
||||
{
|
||||
"enabled_mcp_tools": enabled,
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
_atomic_write_json(meta_path, data)
|
||||
logger.info(
|
||||
"Migrated enabled_mcp_tools for colony %s from metadata.json to tools.json",
|
||||
colony_name,
|
||||
)
|
||||
return enabled
|
||||
|
||||
|
||||
def load_colony_tools_config(colony_name: str) -> list[str] | None:
|
||||
"""Return the colony's MCP tool allowlist, or ``None`` for default-allow.
|
||||
|
||||
Order of resolution:
|
||||
1. ``tools.json`` sidecar (authoritative).
|
||||
2. Legacy ``metadata.json`` field (migrated and deleted on first read).
|
||||
3. ``None`` — default "allow every MCP tool".
|
||||
"""
|
||||
path = tools_config_path(colony_name)
|
||||
if path.exists():
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Invalid %s; treating as default-allow", path)
|
||||
return None
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
raw = data.get("enabled_mcp_tools")
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, list) and all(isinstance(x, str) for x in raw):
|
||||
return raw
|
||||
logger.warning("Unexpected enabled_mcp_tools shape in %s; ignoring", path)
|
||||
return None
|
||||
|
||||
return _migrate_from_metadata_if_needed(colony_name)
|
||||
|
||||
|
||||
def update_colony_tools_config(
|
||||
colony_name: str,
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
) -> list[str] | None:
|
||||
"""Persist a colony's MCP allowlist to ``tools.json``.
|
||||
|
||||
Raises ``FileNotFoundError`` if the colony's directory is missing.
|
||||
"""
|
||||
colony_dir = COLONIES_DIR / colony_name
|
||||
if not colony_dir.exists():
|
||||
raise FileNotFoundError(f"Colony directory not found: {colony_name}")
|
||||
_atomic_write_json(
|
||||
tools_config_path(colony_name),
|
||||
{
|
||||
"enabled_mcp_tools": enabled_mcp_tools,
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
return enabled_mcp_tools
|
||||
@@ -809,16 +809,28 @@ class EventBus:
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str | None = None,
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM turn completion with stop reason and model metadata."""
|
||||
"""Emit LLM turn completion with stop reason and model metadata.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (already inside provider ``prompt_tokens``).
|
||||
Subscribers should display them, not add them to a total.
|
||||
|
||||
``cost_usd`` is the USD cost for this turn when known (Anthropic,
|
||||
OpenAI, OpenRouter). 0.0 means unreported (not free).
|
||||
"""
|
||||
data: dict = {
|
||||
"stop_reason": stop_reason,
|
||||
"model": model,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cached_tokens": cached_tokens,
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
"cost_usd": cost_usd,
|
||||
}
|
||||
if iteration is not None:
|
||||
data["iteration"] = iteration
|
||||
|
||||
@@ -653,10 +653,17 @@ class AntigravityProvider(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
import asyncio # noqa: PLC0415
|
||||
import concurrent.futures # noqa: PLC0415
|
||||
|
||||
# Antigravity (Google's proprietary endpoint) doesn't expose a
|
||||
# cache_control hook. Concatenate the dynamic suffix so its shape
|
||||
# matches the legacy single-string call site.
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
|
||||
|
||||
|
||||
@@ -1,114 +1,32 @@
|
||||
"""Model capability checks for LLM providers.
|
||||
|
||||
Vision support rules are derived from official vendor documentation:
|
||||
- ZAI (z.ai): docs.z.ai/guides/vlm — GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only
|
||||
- MiniMax: platform.minimax.io/docs — minimax-vl-01 is vision; M2.x are text-only
|
||||
- DeepSeek: api-docs.deepseek.com — deepseek-vl2 is vision; chat/reasoner are text-only
|
||||
- Cerebras: inference-docs.cerebras.ai — no vision models at all
|
||||
- Groq: console.groq.com/docs/vision — vision capable; treat as supported by default
|
||||
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
|
||||
don't reliably indicate vision support, so users must configure explicitly
|
||||
Vision support is sourced from the curated ``model_catalog.json``. Each model
|
||||
entry carries an optional ``supports_vision`` boolean; unknown models default
|
||||
to vision-capable so hosted frontier models work out of the box. To toggle
|
||||
support for a model, edit its catalog entry rather than this file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from framework.llm.model_catalog import model_supports_vision
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.llm.provider import Tool
|
||||
|
||||
|
||||
def _model_name(model: str) -> str:
|
||||
"""Return the bare model name after stripping any 'provider/' prefix."""
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
# Step 1: explicit vision allow-list — these always support images regardless
|
||||
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
|
||||
# is allowed even though glm-4.6 is denied.
|
||||
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
|
||||
"glm-4v", # GLM-4V series (legacy)
|
||||
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
|
||||
# DeepSeek vision models
|
||||
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
|
||||
# MiniMax vision model
|
||||
"minimax-vl", # minimax-vl-01
|
||||
)
|
||||
|
||||
# Step 2: provider-level deny — every model from this provider is text-only.
|
||||
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
|
||||
# Cerebras: inference-docs.cerebras.ai lists only text models
|
||||
"cerebras/",
|
||||
# Local runners: model names don't reliably indicate vision support
|
||||
"ollama/",
|
||||
"ollama_chat/",
|
||||
"lm_studio/",
|
||||
"vllm/",
|
||||
"llamacpp/",
|
||||
)
|
||||
|
||||
# Step 3: per-model deny — text-only models within otherwise mixed providers.
|
||||
# Matched against the bare model name (provider prefix stripped, lower-cased).
|
||||
# The vision allow-list above is checked first, so vision variants of the same
|
||||
# family are already handled before these deny patterns are reached.
|
||||
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# --- ZAI / GLM family ---
|
||||
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
|
||||
# vision: glm-4v, glm-4.6v (caught by allow-list above)
|
||||
"glm-5",
|
||||
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"zai-glm",
|
||||
# --- DeepSeek ---
|
||||
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
# vision: deepseek-vl2 (caught by allow-list above)
|
||||
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
|
||||
# VL models are allowed through and rely on LiteLLM's native VL support.
|
||||
"deepseek-chat",
|
||||
"deepseek-coder",
|
||||
"deepseek-reasoner",
|
||||
# --- MiniMax ---
|
||||
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
|
||||
# vision: minimax-vl-01 (caught by allow-list above)
|
||||
"minimax-m2",
|
||||
"minimax-text",
|
||||
"abab",
|
||||
)
|
||||
|
||||
|
||||
def supports_image_tool_results(model: str) -> bool:
|
||||
"""Return whether *model* can receive image content in messages.
|
||||
|
||||
Used to gate both user-message images and tool-result image blocks.
|
||||
|
||||
Logic (checked in order):
|
||||
1. Vision allow-list → True (known vision model, skip all denies)
|
||||
2. Provider deny → False (entire provider is text-only)
|
||||
3. Model deny → False (specific text-only model within a mixed provider)
|
||||
4. Default → True (assume capable; unknown providers and models)
|
||||
Thin wrapper over :func:`model_supports_vision` so existing call sites
|
||||
keep working. Used to gate both user-message images and tool-result
|
||||
image blocks. Empty model strings are treated as capable so the default
|
||||
code path doesn't strip images before a provider is selected.
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
bare = _model_name(model_lower)
|
||||
|
||||
# 1. Explicit vision allow — takes priority over all denies
|
||||
if any(bare.startswith(p) for p in _VISION_ALLOW_BARE_PREFIXES):
|
||||
if not model:
|
||||
return True
|
||||
|
||||
# 2. Provider-level deny (all models from this provider are text-only)
|
||||
if any(model_lower.startswith(p) for p in _TEXT_ONLY_PROVIDER_PREFIXES):
|
||||
return False
|
||||
|
||||
# 3. Per-model deny (text-only variants within mixed-capability families)
|
||||
if any(bare.startswith(p) for p in _TEXT_ONLY_MODEL_BARE_PREFIXES):
|
||||
return False
|
||||
|
||||
# 5. Default: assume vision capable
|
||||
# Covers: OpenAI, Anthropic, Google, Mistral, Kimi, and other hosted providers
|
||||
return True
|
||||
return model_supports_vision(model)
|
||||
|
||||
|
||||
def filter_tools_for_model(tools: list[Tool], model: str) -> tuple[list[Tool], list[str]]:
|
||||
|
||||
+350
-28
@@ -33,6 +33,7 @@ except ImportError:
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
|
||||
from framework.llm.model_catalog import get_model_pricing
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
@@ -213,9 +214,72 @@ _CACHE_CONTROL_PREFIXES = (
|
||||
"glm-",
|
||||
)
|
||||
|
||||
# OpenRouter sub-provider prefixes whose upstream API honors `cache_control`.
|
||||
# OpenRouter passes the marker through to the underlying provider for these.
|
||||
# (See https://openrouter.ai/docs/guides/best-practices/prompt-caching.)
|
||||
# OpenAI/DeepSeek/Groq/Grok/Moonshot route through OpenRouter but cache
|
||||
# automatically server-side — sending cache_control there is a no-op, not a
|
||||
# win, and they need a separate prefix-stability fix to actually get hits.
|
||||
_OPENROUTER_CACHE_CONTROL_PREFIXES = (
|
||||
"openrouter/anthropic/",
|
||||
"openrouter/google/gemini-",
|
||||
"openrouter/z-ai/glm",
|
||||
"openrouter/minimax/",
|
||||
)
|
||||
|
||||
|
||||
def _model_supports_cache_control(model: str) -> bool:
|
||||
return any(model.startswith(p) for p in _CACHE_CONTROL_PREFIXES)
|
||||
if any(model.startswith(p) for p in _CACHE_CONTROL_PREFIXES):
|
||||
return True
|
||||
return any(model.startswith(p) for p in _OPENROUTER_CACHE_CONTROL_PREFIXES)
|
||||
|
||||
|
||||
def _build_system_message(
|
||||
system: str,
|
||||
system_dynamic_suffix: str | None,
|
||||
model: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Construct the system-role message for the chat completion.
|
||||
|
||||
Returns ``None`` when there is nothing to send.
|
||||
|
||||
Two-block split path — used when the caller supplied a non-empty
|
||||
``system_dynamic_suffix`` AND the provider honors ``cache_control``
|
||||
(Anthropic, MiniMax, Z-AI/GLM). We emit ``content`` as a list of two
|
||||
text blocks with an ephemeral ``cache_control`` marker on the first
|
||||
block only. The prompt cache keeps the static prefix warm across
|
||||
turns and across iterations within a turn; only the small dynamic
|
||||
tail is recomputed on every request.
|
||||
|
||||
Single-string path — used for every other case (no suffix provided,
|
||||
or provider doesn't honor ``cache_control``). We concatenate
|
||||
``system`` + ``\\n\\n`` + ``system_dynamic_suffix`` and attach
|
||||
``cache_control`` to the whole message when the provider supports
|
||||
it. This is byte-identical to the pre-split behavior for all
|
||||
non-cache-control providers (OpenAI, Gemini, Groq, Ollama, etc.).
|
||||
"""
|
||||
if not system and not system_dynamic_suffix:
|
||||
return None
|
||||
if system_dynamic_suffix and _model_supports_cache_control(model):
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if system:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": system,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
)
|
||||
content_blocks.append({"type": "text", "text": system_dynamic_suffix})
|
||||
return {"role": "system", "content": content_blocks}
|
||||
# Single-string path (legacy or no-cache-control provider).
|
||||
combined = system
|
||||
if system_dynamic_suffix:
|
||||
combined = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": combined}
|
||||
if _model_supports_cache_control(model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
return sys_msg
|
||||
|
||||
|
||||
# Kimi For Coding uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
@@ -297,6 +361,171 @@ FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
MAX_FAILED_REQUEST_DUMPS = 50
|
||||
|
||||
|
||||
def _cost_from_catalog_pricing(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> float:
|
||||
"""Last-resort cost calculation using curated catalog pricing.
|
||||
|
||||
Consulted only when the provider response carries no native cost and
|
||||
LiteLLM's own catalog has no pricing for ``model``. Reads
|
||||
``pricing_usd_per_mtok`` from ``model_catalog.json``. Rates are USD per
|
||||
million tokens.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (see ``_extract_cache_tokens``), so subtract them from
|
||||
the base input count to avoid double-billing. If a cache rate is absent,
|
||||
fall back to the plain input rate.
|
||||
"""
|
||||
if not model or (input_tokens == 0 and output_tokens == 0):
|
||||
return 0.0
|
||||
pricing = get_model_pricing(model)
|
||||
if pricing is None and "/" in model:
|
||||
# LiteLLM prefixes some ids (e.g. "openrouter/z-ai/glm-5.1"); the
|
||||
# catalog stores the bare form ("z-ai/glm-5.1"). Strip one segment.
|
||||
pricing = get_model_pricing(model.split("/", 1)[1])
|
||||
if pricing is None:
|
||||
return 0.0
|
||||
|
||||
per_mtok_in = pricing.get("input", 0.0)
|
||||
per_mtok_out = pricing.get("output", 0.0)
|
||||
per_mtok_cache_read = pricing.get("cache_read", per_mtok_in)
|
||||
per_mtok_cache_write = pricing.get("cache_creation", per_mtok_in)
|
||||
|
||||
plain_input = max(input_tokens - cached_tokens - cache_creation_tokens, 0)
|
||||
total = (
|
||||
plain_input * per_mtok_in
|
||||
+ cached_tokens * per_mtok_cache_read
|
||||
+ cache_creation_tokens * per_mtok_cache_write
|
||||
+ output_tokens * per_mtok_out
|
||||
) / 1_000_000
|
||||
return float(total) if total > 0 else 0.0
|
||||
|
||||
|
||||
def _extract_cost(response: Any, model: str) -> float:
|
||||
"""Pull the USD cost for a non-streaming completion response.
|
||||
|
||||
Sources checked, in priority order:
|
||||
1. ``usage.cost`` — populated when OpenRouter returns native cost via
|
||||
``usage: {include: true}`` or when ``litellm.include_cost_in_streaming_usage``
|
||||
is on.
|
||||
2. ``response._hidden_params["response_cost"]`` — set by LiteLLM's
|
||||
logging layer after most successful completions.
|
||||
3. ``litellm.completion_cost(...)`` — computes from the model pricing
|
||||
table; works across Anthropic, OpenAI, and OpenRouter as long as the
|
||||
model is in LiteLLM's catalog.
|
||||
4. ``pricing_usd_per_mtok`` from the curated model catalog — covers
|
||||
models (e.g. GLM, Kimi, MiniMax) that LiteLLM doesn't price.
|
||||
|
||||
Returns 0.0 for unpriced models or unexpected response shapes — cost is a
|
||||
display concern, never let it break the hot path. For streaming paths
|
||||
where the aggregate response isn't a full ``ModelResponse``, use
|
||||
:func:`_cost_from_tokens` with the already-extracted token counts.
|
||||
"""
|
||||
if response is None:
|
||||
return 0.0
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_cost = getattr(usage, "cost", None) if usage is not None else None
|
||||
if isinstance(usage_cost, (int, float)) and usage_cost > 0:
|
||||
return float(usage_cost)
|
||||
|
||||
hidden = getattr(response, "_hidden_params", None)
|
||||
if isinstance(hidden, dict):
|
||||
hp_cost = hidden.get("response_cost")
|
||||
if isinstance(hp_cost, (int, float)) and hp_cost > 0:
|
||||
return float(hp_cost)
|
||||
|
||||
try:
|
||||
import litellm as _litellm
|
||||
|
||||
computed = _litellm.completion_cost(completion_response=response, model=model)
|
||||
if isinstance(computed, (int, float)) and computed > 0:
|
||||
return float(computed)
|
||||
except Exception as exc:
|
||||
logger.debug("[cost] completion_cost failed for %s: %s", model, exc)
|
||||
|
||||
if usage is not None:
|
||||
input_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
|
||||
output_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
fallback = _cost_from_catalog_pricing(model, input_tokens, output_tokens, cache_read, cache_creation)
|
||||
if fallback > 0:
|
||||
return fallback
|
||||
return 0.0
|
||||
|
||||
|
||||
def _cost_from_tokens(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> float:
|
||||
"""Compute USD cost from already-normalized token counts.
|
||||
|
||||
Used on streaming paths where the aggregate ``response`` is the stream
|
||||
wrapper (not a full ``ModelResponse``) and ``litellm.completion_cost`` on
|
||||
it either no-ops or raises. Calls ``litellm.cost_per_token`` directly
|
||||
with the cache-aware inputs so Anthropic's 5-min-write / cache-read
|
||||
multipliers are applied correctly.
|
||||
"""
|
||||
if not model or (input_tokens == 0 and output_tokens == 0):
|
||||
return 0.0
|
||||
try:
|
||||
import litellm as _litellm
|
||||
|
||||
prompt_cost, completion_cost = _litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
)
|
||||
total = (prompt_cost or 0.0) + (completion_cost or 0.0)
|
||||
if total > 0:
|
||||
return float(total)
|
||||
except Exception as exc:
|
||||
logger.debug("[cost] cost_per_token failed for %s: %s", model, exc)
|
||||
return _cost_from_catalog_pricing(model, input_tokens, output_tokens, cached_tokens, cache_creation_tokens)
|
||||
|
||||
|
||||
def _extract_cache_tokens(usage: Any) -> tuple[int, int]:
|
||||
"""Pull (cache_read, cache_creation) from a LiteLLM usage object.
|
||||
|
||||
Both are subsets of ``prompt_tokens`` already — providers count them
|
||||
inside the input total. Surface separately for visibility, never sum.
|
||||
|
||||
Field names vary by provider/proxy; check the known shapes in priority
|
||||
order and fall back to 0:
|
||||
|
||||
cache_read:
|
||||
- ``prompt_tokens_details.cached_tokens`` — OpenAI-shape; also what
|
||||
LiteLLM normalizes Anthropic and OpenRouter into.
|
||||
- ``cache_read_input_tokens`` — raw Anthropic field name.
|
||||
|
||||
cache_creation:
|
||||
- ``prompt_tokens_details.cache_write_tokens`` — OpenRouter's
|
||||
normalized field for cache writes (verified empirically against
|
||||
``openrouter/anthropic/*`` and ``openrouter/z-ai/*`` responses).
|
||||
- ``cache_creation_input_tokens`` — raw Anthropic top-level field.
|
||||
"""
|
||||
if not usage:
|
||||
return 0, 0
|
||||
_details = getattr(usage, "prompt_tokens_details", None)
|
||||
cache_read = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
cache_creation = (getattr(_details, "cache_write_tokens", 0) or 0 if _details is not None else 0) or (
|
||||
getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
)
|
||||
return cache_read, cache_creation
|
||||
|
||||
|
||||
def _estimate_tokens(model: str, messages: list[dict]) -> tuple[int, str]:
|
||||
"""Estimate token count for messages. Returns (token_count, method)."""
|
||||
# Try litellm's token counter first
|
||||
@@ -1015,12 +1244,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=response.choices[0].finish_reason or "",
|
||||
raw_response=response,
|
||||
)
|
||||
@@ -1169,8 +1403,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async version of complete(). Uses litellm.acompletion — non-blocking."""
|
||||
"""Async version of complete(). Uses litellm.acompletion — non-blocking.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail. When set and
|
||||
the provider honors ``cache_control``, ``system`` is sent as the
|
||||
cached prefix and the suffix trails as an uncached second content
|
||||
block. Otherwise the two strings are concatenated into a single
|
||||
system message (legacy behavior).
|
||||
"""
|
||||
# Codex ChatGPT backend requires streaming — route through stream() which
|
||||
# already handles Codex quirks and has proper tool call accumulation.
|
||||
if self._codex_backend:
|
||||
@@ -1181,6 +1423,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
return await self._collect_stream_to_response(stream_iter)
|
||||
|
||||
@@ -1188,10 +1431,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
sys_msg = _build_system_message(system, system_dynamic_suffix, self.model)
|
||||
if sys_msg is not None:
|
||||
full_messages.append(sys_msg)
|
||||
full_messages.extend(messages)
|
||||
|
||||
@@ -1228,12 +1469,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=response.choices[0].finish_reason or "",
|
||||
raw_response=response,
|
||||
)
|
||||
@@ -1619,6 +1865,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
@@ -1646,7 +1893,19 @@ class LiteLLMProvider(LLMProvider):
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
# If the routed sub-provider honors cache_control (e.g.
|
||||
# openrouter/anthropic/*), split the static prefix from the dynamic
|
||||
# suffix so the prefix stays cache-warm across turns. Otherwise fall
|
||||
# back to a single concatenated string.
|
||||
system_message = _build_system_message(
|
||||
compat_system,
|
||||
system_dynamic_suffix,
|
||||
self.model,
|
||||
)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system_message is not None:
|
||||
full_messages.append(system_message)
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
@@ -1660,9 +1919,21 @@ class LiteLLMProvider(LLMProvider):
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools.
|
||||
|
||||
When the routed sub-provider honors ``cache_control`` (e.g.
|
||||
``openrouter/anthropic/*``), the message builder splits the static
|
||||
prefix from the dynamic suffix so the prefix stays cache-warm.
|
||||
Otherwise the suffix is concatenated into a single system string.
|
||||
"""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
@@ -1683,6 +1954,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
@@ -1690,6 +1963,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
@@ -1704,6 +1980,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
@@ -1724,6 +2001,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
@@ -1747,6 +2025,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
@@ -1758,6 +2039,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int,
|
||||
response_format: dict[str, Any] | None,
|
||||
json_mode: bool,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback path: convert non-stream completion to stream events.
|
||||
|
||||
@@ -1781,6 +2063,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
@@ -1812,6 +2095,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
stop_reason=response.stop_reason or "stop",
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
@@ -1823,6 +2109,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a completion via litellm.acompletion(stream=True).
|
||||
|
||||
@@ -1833,6 +2120,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
Empty responses (e.g. Gemini stealth rate-limits that return 200
|
||||
with no content) are retried with exponential backoff, mirroring
|
||||
the retry behaviour of ``_completion_with_rate_limit_retry``.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail. See
|
||||
``acomplete`` docstring for the two-block split semantics.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
@@ -1852,6 +2142,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
@@ -1862,6 +2153,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
@@ -1870,10 +2162,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
sys_msg = _build_system_message(system, system_dynamic_suffix, self.model)
|
||||
if sys_msg is not None:
|
||||
full_messages.append(sys_msg)
|
||||
full_messages.extend(messages)
|
||||
|
||||
@@ -2109,37 +2399,44 @@ class LiteLLMProvider(LLMProvider):
|
||||
type(usage).__name__,
|
||||
)
|
||||
cached_tokens = 0
|
||||
cache_creation_tokens = 0
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
_details = getattr(usage, "prompt_tokens_details", None)
|
||||
cached_tokens = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
logger.debug(
|
||||
"[tokens] finish-chunk usage: input=%d output=%d cached=%d model=%s",
|
||||
"[tokens] finish-chunk usage: input=%d output=%d cached=%d cache_creation=%d model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
self.model,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"[tokens] finish event: input=%d output=%d cached=%d stop=%s model=%s",
|
||||
"[tokens] finish event: input=%d output=%d cached=%d cache_creation=%d stop=%s model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
choice.finish_reason,
|
||||
self.model,
|
||||
)
|
||||
cost_usd = _cost_from_tokens(
|
||||
self.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
)
|
||||
tail_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
model=self.model,
|
||||
)
|
||||
)
|
||||
@@ -2159,19 +2456,36 @@ class LiteLLMProvider(LLMProvider):
|
||||
_usage = calculate_total_usage(chunks=_chunks)
|
||||
input_tokens = _usage.prompt_tokens or 0
|
||||
output_tokens = _usage.completion_tokens or 0
|
||||
_details = getattr(_usage, "prompt_tokens_details", None)
|
||||
cached_tokens = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(_usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
# `calculate_total_usage` aggregates token totals
|
||||
# but discards `prompt_tokens_details` — which is
|
||||
# where OpenRouter puts `cached_tokens` and
|
||||
# `cache_write_tokens`. Recover them directly
|
||||
# from the most recent chunk that carries usage.
|
||||
cached_tokens, cache_creation_tokens = 0, 0
|
||||
for _raw in reversed(_chunks):
|
||||
_raw_usage = getattr(_raw, "usage", None)
|
||||
if _raw_usage is None:
|
||||
continue
|
||||
_cr, _cc = _extract_cache_tokens(_raw_usage)
|
||||
if _cr or _cc:
|
||||
cached_tokens, cache_creation_tokens = _cr, _cc
|
||||
break
|
||||
logger.debug(
|
||||
"[tokens] post-loop chunks fallback: input=%d output=%d cached=%d model=%s",
|
||||
"[tokens] post-loop chunks fallback: input=%d output=%d "
|
||||
"cached=%d cache_creation=%d model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
self.model,
|
||||
)
|
||||
cost_usd = _cost_from_tokens(
|
||||
self.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
)
|
||||
# Patch the FinishEvent already queued with 0 tokens
|
||||
for _i, _ev in enumerate(tail_events):
|
||||
if isinstance(_ev, FinishEvent) and _ev.input_tokens == 0:
|
||||
@@ -2180,6 +2494,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
model=_ev.model,
|
||||
)
|
||||
break
|
||||
@@ -2390,6 +2706,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
cached_tokens = 0
|
||||
cache_creation_tokens = 0
|
||||
stop_reason = ""
|
||||
model = self.model
|
||||
|
||||
@@ -2407,6 +2725,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
elif isinstance(event, FinishEvent):
|
||||
input_tokens = event.input_tokens
|
||||
output_tokens = event.output_tokens
|
||||
cached_tokens = event.cached_tokens
|
||||
cache_creation_tokens = event.cache_creation_tokens
|
||||
stop_reason = event.stop_reason
|
||||
if event.model:
|
||||
model = event.model
|
||||
@@ -2419,6 +2739,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={"tool_calls": tool_calls} if tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -155,8 +155,11 @@ class MockLLMProvider(LLMProvider):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async mock completion (no I/O, returns immediately)."""
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
return self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
@@ -173,6 +176,7 @@ class MockLLMProvider(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
|
||||
@@ -180,6 +184,8 @@ class MockLLMProvider(LLMProvider):
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
"""
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
|
||||
@@ -9,47 +9,65 @@
|
||||
"label": "Haiku 4.5 - Fast + cheap",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 136000
|
||||
"max_context_tokens": 136000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-5-20250929",
|
||||
"label": "Sonnet 4.5 - Best balance",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 136000
|
||||
"max_context_tokens": 136000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "claude-opus-4-6",
|
||||
"label": "Opus 4.6 - Most capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"openai": {
|
||||
"default_model": "gpt-5.4",
|
||||
"default_model": "gpt-5.5",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"label": "GPT-5.4 - Best intelligence",
|
||||
"id": "gpt-5.5",
|
||||
"label": "GPT-5.5 - Frontier coding + reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 960000
|
||||
"max_context_tokens": 1050000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 5.00,
|
||||
"output": 30.00
|
||||
},
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"label": "GPT-5.4 - Previous flagship",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 960000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4-mini",
|
||||
"label": "GPT-5.4 Mini - Faster + cheaper",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 400000
|
||||
"max_context_tokens": 400000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4-nano",
|
||||
"label": "GPT-5.4 Nano - Cheapest high-volume",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 400000
|
||||
"max_context_tokens": 400000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -61,14 +79,16 @@
|
||||
"label": "Gemini 3 Flash - Fast",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gemini-3.1-pro-preview-customtools",
|
||||
"label": "Gemini 3.1 Pro - Best quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -80,28 +100,32 @@
|
||||
"label": "GPT-OSS 120B - Best reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 65536,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-oss-20b",
|
||||
"label": "GPT-OSS 20B - Fast + cheaper",
|
||||
"recommended": false,
|
||||
"max_tokens": 65536,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "llama-3.3-70b-versatile",
|
||||
"label": "Llama 3.3 70B - General purpose",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "llama-3.1-8b-instant",
|
||||
"label": "Llama 3.1 8B - Fastest",
|
||||
"recommended": false,
|
||||
"max_tokens": 131072,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -113,21 +137,24 @@
|
||||
"label": "GPT-OSS 120B - Best production reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "zai-glm-4.7",
|
||||
"label": "Z.ai GLM 4.7 - Strong coding preview",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "qwen-3-235b-a22b-instruct-2507",
|
||||
"label": "Qwen 3 235B Instruct - Frontier preview",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -139,14 +166,20 @@
|
||||
"label": "MiniMax M2.7 - Best coding quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.30,
|
||||
"output": 1.20
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "MiniMax-M2.5",
|
||||
"label": "MiniMax M2.5 - Strong value",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -158,28 +191,32 @@
|
||||
"label": "Mistral Large 3 - Best quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 256000
|
||||
"max_context_tokens": 256000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "mistral-medium-2508",
|
||||
"label": "Mistral Medium 3.1 - Balanced",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "mistral-small-2603",
|
||||
"label": "Mistral Small 4 - Fast + capable",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 256000
|
||||
"max_context_tokens": 256000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "codestral-2508",
|
||||
"label": "Codestral - Coding specialist",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -191,47 +228,71 @@
|
||||
"label": "DeepSeek V3.1 - Best general coding",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8",
|
||||
"label": "Qwen3 Coder 480B - Advanced coding",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 262144
|
||||
"max_context_tokens": 262144,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-oss-120b",
|
||||
"label": "GPT-OSS 120B - Strong reasoning",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"label": "Llama 3.3 70B Turbo - Fast baseline",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"deepseek": {
|
||||
"default_model": "deepseek-chat",
|
||||
"default_model": "deepseek-v4-pro",
|
||||
"models": [
|
||||
{
|
||||
"id": "deepseek-chat",
|
||||
"label": "DeepSeek Chat - Fast default",
|
||||
"id": "deepseek-v4-pro",
|
||||
"label": "DeepSeek V4 Pro - Most capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 8192,
|
||||
"max_context_tokens": 128000
|
||||
"max_tokens": 384000,
|
||||
"max_context_tokens": 1000000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.74,
|
||||
"output": 3.48,
|
||||
"cache_read": 0.145
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "deepseek-v4-flash",
|
||||
"label": "DeepSeek V4 Flash - Fast + cheap",
|
||||
"recommended": true,
|
||||
"max_tokens": 384000,
|
||||
"max_context_tokens": 1000000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.14,
|
||||
"output": 0.28,
|
||||
"cache_read": 0.028
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "deepseek-reasoner",
|
||||
"label": "DeepSeek Reasoner - Deep thinking",
|
||||
"label": "DeepSeek Reasoner - Legacy (deprecating)",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -243,7 +304,13 @@
|
||||
"label": "Kimi K2.5 - Best coding",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 200000
|
||||
"max_context_tokens": 200000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.60,
|
||||
"output": 2.50,
|
||||
"cache_read": 0.15
|
||||
},
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -255,21 +322,30 @@
|
||||
"label": "Queen - Hive native",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "kimi-2.5",
|
||||
"label": "Kimi 2.5 - Via Hive",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "GLM-5",
|
||||
"label": "GLM-5 - Via Hive",
|
||||
"id": "glm-5.1",
|
||||
"label": "GLM-5.1 - Via Hive",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.40,
|
||||
"output": 4.40,
|
||||
"cache_read": 0.26,
|
||||
"cache_creation": 0.0
|
||||
},
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -281,63 +357,82 @@
|
||||
"label": "GPT-5.4 - Best overall",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-sonnet-4.6",
|
||||
"label": "Claude Sonnet 4.6 - Best coding balance",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-opus-4.6",
|
||||
"label": "Claude Opus 4.6 - Most capable",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "google/gemini-3.1-pro-preview-customtools",
|
||||
"label": "Gemini 3.1 Pro Preview - Long-context reasoning",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "qwen/qwen3.6-plus",
|
||||
"label": "Qwen 3.6 Plus - Strong reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "z-ai/glm-5v-turbo",
|
||||
"label": "GLM-5V Turbo - Vision capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 192000
|
||||
"max_context_tokens": 192000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "z-ai/glm-5.1",
|
||||
"label": "GLM-5.1 - Better but Slower",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 192000
|
||||
"max_context_tokens": 192000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.40,
|
||||
"output": 4.40,
|
||||
"cache_read": 0.26,
|
||||
"cache_creation": 0.0
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "minimax/minimax-m2.7",
|
||||
"label": "Minimax M2.7 - Minimax flagship",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.30,
|
||||
"output": 1.20
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "xiaomi/mimo-v2-pro",
|
||||
"label": "MiMo V2 Pro - Xiaomi multimodal",
|
||||
"recommended": true,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -352,7 +447,7 @@
|
||||
"zai_code": {
|
||||
"provider": "openai",
|
||||
"api_key_env_var": "ZAI_API_KEY",
|
||||
"model": "glm-5",
|
||||
"model": "glm-5.1",
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000,
|
||||
"api_base": "https://api.z.ai/api/coding/paas/v4"
|
||||
@@ -399,8 +494,8 @@
|
||||
"recommended": false
|
||||
},
|
||||
{
|
||||
"id": "GLM-5",
|
||||
"label": "GLM-5",
|
||||
"id": "glm-5.1",
|
||||
"label": "glm-5.1",
|
||||
"recommended": false
|
||||
}
|
||||
]
|
||||
|
||||
@@ -27,6 +27,28 @@ def _require_list(value: Any, path: str) -> list[Any]:
|
||||
return value
|
||||
|
||||
|
||||
_PRICING_KEYS = ("input", "output", "cache_read", "cache_creation")
|
||||
|
||||
|
||||
def _validate_pricing(value: Any, path: str) -> None:
|
||||
"""Validate an optional ``pricing_usd_per_mtok`` block.
|
||||
|
||||
Keys are USD-per-million-tokens rates. ``input``/``output`` are required;
|
||||
``cache_read``/``cache_creation`` are optional. All values must be
|
||||
non-negative numbers. Used as a last-resort fallback when neither the
|
||||
provider nor LiteLLM's catalog reports a cost.
|
||||
"""
|
||||
pricing = _require_mapping(value, path)
|
||||
for key in ("input", "output"):
|
||||
if key not in pricing:
|
||||
raise ModelCatalogError(f"{path}.{key} is required")
|
||||
for key, rate in pricing.items():
|
||||
if key not in _PRICING_KEYS:
|
||||
raise ModelCatalogError(f"{path}.{key} is not a recognized pricing field")
|
||||
if not isinstance(rate, (int, float)) or isinstance(rate, bool) or rate < 0:
|
||||
raise ModelCatalogError(f"{path}.{key} must be a non-negative number")
|
||||
|
||||
|
||||
def _validate_model_catalog(data: dict[str, Any]) -> dict[str, Any]:
|
||||
providers = _require_mapping(data.get("providers"), "providers")
|
||||
|
||||
@@ -69,6 +91,14 @@ def _validate_model_catalog(data: dict[str, Any]) -> dict[str, Any]:
|
||||
if not isinstance(value, int) or value <= 0:
|
||||
raise ModelCatalogError(f"{model_path}.{key} must be a positive integer")
|
||||
|
||||
pricing = model_map.get("pricing_usd_per_mtok")
|
||||
if pricing is not None:
|
||||
_validate_pricing(pricing, f"{model_path}.pricing_usd_per_mtok")
|
||||
|
||||
supports_vision = model_map.get("supports_vision")
|
||||
if supports_vision is not None and not isinstance(supports_vision, bool):
|
||||
raise ModelCatalogError(f"{model_path}.supports_vision must be a boolean when present")
|
||||
|
||||
if not default_found:
|
||||
raise ModelCatalogError(
|
||||
f"{provider_path}.default_model={default_model!r} is not present in {provider_path}.models"
|
||||
@@ -184,6 +214,53 @@ def get_model_limits(provider: str, model_id: str) -> tuple[int, int] | None:
|
||||
return int(model["max_tokens"]), int(model["max_context_tokens"])
|
||||
|
||||
|
||||
def get_model_pricing(model_id: str) -> dict[str, float] | None:
|
||||
"""Return ``pricing_usd_per_mtok`` for a model id, searching all providers.
|
||||
|
||||
Returns ``None`` when the model is absent from the catalog or has no
|
||||
pricing entry. Used by the cost-extraction fallback in ``litellm.py``
|
||||
when the provider response and LiteLLM's catalog both come up empty.
|
||||
"""
|
||||
if not model_id:
|
||||
return None
|
||||
for provider_info in load_model_catalog()["providers"].values():
|
||||
for model in provider_info["models"]:
|
||||
if model["id"] == model_id:
|
||||
pricing = model.get("pricing_usd_per_mtok")
|
||||
if pricing is None:
|
||||
return None
|
||||
return {key: float(rate) for key, rate in pricing.items()}
|
||||
return None
|
||||
|
||||
|
||||
def model_supports_vision(model_id: str) -> bool:
|
||||
"""Return whether *model_id* supports image inputs per the curated catalog.
|
||||
|
||||
Looks up the bare model id (and the provider-prefix-stripped form) in the
|
||||
catalog. Returns the model's ``supports_vision`` flag when found, defaulting
|
||||
to ``True`` for unknown models or when the flag is absent — assume vision
|
||||
capable for hosted providers, since modern frontier models support images
|
||||
by default and the captioning fallback is more expensive than just letting
|
||||
the provider handle the image.
|
||||
"""
|
||||
if not model_id:
|
||||
return True
|
||||
|
||||
candidates = [model_id]
|
||||
if "/" in model_id:
|
||||
candidates.append(model_id.split("/", 1)[1])
|
||||
|
||||
for candidate in candidates:
|
||||
for provider_info in load_model_catalog()["providers"].values():
|
||||
for model in provider_info["models"]:
|
||||
if model["id"] == candidate:
|
||||
flag = model.get("supports_vision")
|
||||
if isinstance(flag, bool):
|
||||
return flag
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def get_preset(preset_id: str) -> dict[str, Any] | None:
|
||||
"""Return one preset entry."""
|
||||
preset = load_model_catalog()["presets"].get(preset_id)
|
||||
|
||||
@@ -10,12 +10,24 @@ from typing import Any
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM call."""
|
||||
"""Response from an LLM call.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (providers report them inside ``prompt_tokens``).
|
||||
Surface them for visibility; do not add to a total.
|
||||
|
||||
``cost_usd`` is the per-call USD cost when the provider / pricing table
|
||||
can produce one (Anthropic, OpenAI, OpenRouter are supported). 0.0 when
|
||||
unknown or unpriced — treat as "unreported", not "free".
|
||||
"""
|
||||
|
||||
content: str
|
||||
model: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cost_usd: float = 0.0
|
||||
stop_reason: str = ""
|
||||
raw_response: Any = None
|
||||
|
||||
@@ -110,19 +122,28 @@ class LLMProvider(ABC):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> "LLMResponse":
|
||||
"""Async version of complete(). Non-blocking on the event loop.
|
||||
|
||||
Default implementation offloads the sync complete() to a thread pool.
|
||||
Subclasses SHOULD override for native async I/O.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail for providers
|
||||
that honor ``cache_control`` (see LiteLLMProvider for semantics).
|
||||
The default implementation concatenates it onto ``system`` since the
|
||||
sync ``complete()`` path does not support the split.
|
||||
"""
|
||||
combined_system = system
|
||||
if system_dynamic_suffix:
|
||||
combined_system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self.complete,
|
||||
messages=messages,
|
||||
system=system,
|
||||
system=combined_system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
@@ -137,6 +158,7 @@ class LLMProvider(ABC):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator["StreamEvent"]:
|
||||
"""
|
||||
Stream a completion as an async iterator of StreamEvents.
|
||||
@@ -147,6 +169,9 @@ class LLMProvider(ABC):
|
||||
Tool orchestration is the CALLER's responsibility:
|
||||
- Caller detects ToolCallEvent, executes tool, adds result
|
||||
to messages, calls stream() again.
|
||||
|
||||
``system_dynamic_suffix`` is forwarded to ``acomplete``; see its
|
||||
docstring for the two-block split semantics.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
@@ -159,6 +184,7 @@ class LLMProvider(ABC):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
@@ -166,6 +192,9 @@ class LLMProvider(ABC):
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,13 +65,23 @@ class ReasoningDeltaEvent:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinishEvent:
|
||||
"""The LLM has finished generating."""
|
||||
"""The LLM has finished generating.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` — providers count both inside ``prompt_tokens`` already.
|
||||
Surface them separately for visibility; never add to a total.
|
||||
|
||||
``cost_usd`` is the per-turn USD cost when the provider or LiteLLM's
|
||||
pricing table supplies one; 0.0 means unreported (not free).
|
||||
"""
|
||||
|
||||
type: Literal["finish"] = "finish"
|
||||
stop_reason: str = ""
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cost_usd: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ Nodes that need browser access declare ``tools: {policy: "all"}`` in their
|
||||
agent.json config.
|
||||
|
||||
Note: the canonical source of truth for browser automation guidance is
|
||||
the ``browser-automation`` default skill at
|
||||
``core/framework/skills/_default_skills/browser-automation/SKILL.md``.
|
||||
the ``browser-automation`` preset skill at
|
||||
``core/framework/skills/_preset_skills/browser-automation/SKILL.md``.
|
||||
Activate that skill for the full decision tree. This module holds a
|
||||
compact subset suitable for direct inlining into a node's system prompt
|
||||
when a skill activation is not desired.
|
||||
|
||||
@@ -543,6 +543,10 @@ class NodeContext:
|
||||
# Dynamic memory provider — when set, EventLoopNode rebuilds the
|
||||
# system prompt with the latest memory block each iteration.
|
||||
dynamic_memory_provider: Any = None # Callable[[], str] | None
|
||||
# Surgical skills-catalog refresh, same contract as AgentContext's
|
||||
# field of the same name. Lets workers pick up UI-driven skill
|
||||
# toggles without rebuilding the full system prompt each turn.
|
||||
dynamic_skills_catalog_provider: Any = None # Callable[[], str] | None
|
||||
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
|
||||
@@ -140,6 +140,25 @@ async def cors_middleware(request: web.Request, handler):
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def no_cache_api_middleware(request: web.Request, handler):
|
||||
"""Prevent browsers from caching API responses.
|
||||
|
||||
Without this, a one-off bad response (e.g. the SPA catch-all leaking
|
||||
index.html for an /api/* URL before a route was registered) can get
|
||||
pinned in the browser's disk cache and replayed forever, since our
|
||||
JSON handlers don't emit ETag/Last-Modified and browsers fall back
|
||||
to heuristic freshness.
|
||||
"""
|
||||
try:
|
||||
response = await handler(request)
|
||||
except web.HTTPException as exc:
|
||||
response = exc
|
||||
if request.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def error_middleware(request: web.Request, handler):
|
||||
"""Catch exceptions and return JSON error responses.
|
||||
@@ -268,7 +287,7 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
Returns:
|
||||
Configured aiohttp Application ready to run.
|
||||
"""
|
||||
app = web.Application(middlewares=[cors_middleware, error_middleware])
|
||||
app = web.Application(middlewares=[cors_middleware, no_cache_api_middleware, error_middleware])
|
||||
|
||||
# Initialize credential store (before SessionManager so it can be shared)
|
||||
from framework.credentials.store import CredentialStore
|
||||
@@ -325,16 +344,20 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
app.router.add_get("/api/browser/status/stream", handle_browser_status_stream)
|
||||
|
||||
# Register route modules
|
||||
from framework.server.routes_colony_tools import register_routes as register_colony_tools_routes
|
||||
from framework.server.routes_colony_workers import register_routes as register_colony_worker_routes
|
||||
from framework.server.routes_config import register_routes as register_config_routes
|
||||
from framework.server.routes_credentials import register_routes as register_credential_routes
|
||||
from framework.server.routes_events import register_routes as register_event_routes
|
||||
from framework.server.routes_execution import register_routes as register_execution_routes
|
||||
from framework.server.routes_logs import register_routes as register_log_routes
|
||||
from framework.server.routes_mcp import register_routes as register_mcp_routes
|
||||
from framework.server.routes_messages import register_routes as register_message_routes
|
||||
from framework.server.routes_prompts import register_routes as register_prompt_routes
|
||||
from framework.server.routes_queen_tools import register_routes as register_queen_tools_routes
|
||||
from framework.server.routes_queens import register_routes as register_queen_routes
|
||||
from framework.server.routes_sessions import register_routes as register_session_routes
|
||||
from framework.server.routes_skills import register_routes as register_skills_routes
|
||||
from framework.server.routes_workers import register_routes as register_worker_routes
|
||||
|
||||
register_config_routes(app)
|
||||
@@ -346,8 +369,12 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
register_worker_routes(app)
|
||||
register_log_routes(app)
|
||||
register_queen_routes(app)
|
||||
register_queen_tools_routes(app)
|
||||
register_colony_tools_routes(app)
|
||||
register_mcp_routes(app)
|
||||
register_colony_worker_routes(app)
|
||||
register_prompt_routes(app)
|
||||
register_skills_routes(app)
|
||||
|
||||
# Static file serving — Option C production mode
|
||||
# If frontend/dist/ exists, serve built frontend files on /
|
||||
|
||||
@@ -253,6 +253,92 @@ async def materialize_queen_identity(
|
||||
)
|
||||
|
||||
|
||||
def build_queen_tool_registry_bare() -> tuple[Any, dict[str, list[dict[str, Any]]]]:
|
||||
"""Build a Queen ``ToolRegistry`` and a (server_name → tools) catalog.
|
||||
|
||||
Used by the Tool Library GET route to populate the MCP tool surface
|
||||
without needing a live queen session. We DO NOT register queen
|
||||
lifecycle tools here (they require a Session stub); the catalog only
|
||||
covers MCP-origin tools, which is what the allowlist gates.
|
||||
|
||||
Loading MCP servers spawns subprocesses, so call this once per
|
||||
backend process and cache the result.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import framework.agents.queen as _queen_pkg
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
|
||||
queen_registry = ToolRegistry()
|
||||
queen_pkg_dir = Path(_queen_pkg.__file__).parent
|
||||
|
||||
mcp_config = queen_pkg_dir / "mcp_servers.json"
|
||||
if mcp_config.exists():
|
||||
try:
|
||||
queen_registry.load_mcp_config(mcp_config)
|
||||
except Exception:
|
||||
logger.warning("build_queen_tool_registry_bare: MCP config failed", exc_info=True)
|
||||
|
||||
try:
|
||||
reg = MCPRegistry()
|
||||
reg.initialize()
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = reg.load_agent_selection(queen_pkg_dir)
|
||||
|
||||
already = {cfg.get("name") for cfg in registry_configs if cfg.get("name")}
|
||||
extra: list[str] = []
|
||||
try:
|
||||
for entry in reg.list_installed():
|
||||
if entry.get("source") != "local":
|
||||
continue
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
if name and name not in already:
|
||||
extra.append(name)
|
||||
except Exception:
|
||||
pass
|
||||
if extra:
|
||||
try:
|
||||
extra_configs = reg.resolve_for_agent(include=extra)
|
||||
registry_configs = list(registry_configs) + [reg._server_config_to_dict(c) for c in extra_configs]
|
||||
except Exception:
|
||||
logger.debug("build_queen_tool_registry_bare: resolve_for_agent(extra) failed", exc_info=True)
|
||||
|
||||
if registry_configs:
|
||||
queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=False,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("build_queen_tool_registry_bare: MCP registry load failed", exc_info=True)
|
||||
|
||||
# Build the catalog.
|
||||
tools_by_name = queen_registry.get_tools()
|
||||
server_map = dict(getattr(queen_registry, "_mcp_server_tools", {}) or {})
|
||||
catalog: dict[str, list[dict[str, Any]]] = {}
|
||||
for server_name in sorted(server_map):
|
||||
entries: list[dict[str, Any]] = []
|
||||
for tool_name in sorted(server_map[server_name]):
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.parameters,
|
||||
}
|
||||
)
|
||||
catalog[server_name] = entries
|
||||
|
||||
return queen_registry, catalog
|
||||
|
||||
|
||||
async def create_queen(
|
||||
session: Session,
|
||||
session_manager: Any,
|
||||
@@ -326,6 +412,45 @@ async def create_queen(
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
|
||||
|
||||
# Auto-include every user-added local MCP server that the repo
|
||||
# selection hasn't already loaded. Users register servers via
|
||||
# the `/api/mcp/servers` route (or `hive mcp add`); they live in
|
||||
# ~/.hive/mcp_registry/installed.json with source == "local".
|
||||
# New servers take effect on the next queen session start; the
|
||||
# prompt cache and ToolRegistry are still loaded once per boot.
|
||||
already_loaded_names = {cfg.get("name") for cfg in registry_configs if cfg.get("name")}
|
||||
extra_names: list[str] = []
|
||||
try:
|
||||
for entry in registry.list_installed():
|
||||
if entry.get("source") != "local":
|
||||
continue
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
if not name or name in already_loaded_names:
|
||||
continue
|
||||
extra_names.append(name)
|
||||
except Exception:
|
||||
logger.debug("Queen: list_installed() failed while auto-including user servers", exc_info=True)
|
||||
|
||||
if extra_names:
|
||||
try:
|
||||
extra_configs = registry.resolve_for_agent(include=extra_names)
|
||||
extra_dicts = [registry._server_config_to_dict(c) for c in extra_configs]
|
||||
registry_configs = list(registry_configs) + extra_dicts
|
||||
logger.info(
|
||||
"Queen: auto-including %d user-added MCP server(s): %s",
|
||||
len(extra_dicts),
|
||||
[c.get("name") for c in extra_dicts],
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Queen: failed to resolve user-added MCP servers %s",
|
||||
extra_names,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if registry_configs:
|
||||
results = queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
@@ -417,6 +542,71 @@ async def create_queen(
|
||||
sorted(t.name for t in phase_state.incubating_tools),
|
||||
)
|
||||
|
||||
# ---- Per-queen MCP tool allowlist --------------------------------
|
||||
# Capture the set of MCP-origin tool names so the allowlist in
|
||||
# ``QueenPhaseState`` only gates MCP tools (lifecycle and synthetic
|
||||
# tools always pass through). Then apply the queen profile's stored
|
||||
# allowlist (if any) and memoize the filtered independent tool list.
|
||||
mcp_server_tools_map: dict[str, set[str]] = dict(getattr(queen_registry, "_mcp_server_tools", {}))
|
||||
phase_state.mcp_tool_names_all = set().union(*mcp_server_tools_map.values()) if mcp_server_tools_map else set()
|
||||
# The queen's MCP tool allowlist now lives in a dedicated
|
||||
# ``tools.json`` sidecar next to ``profile.yaml``. ``load_queen_tools_config``
|
||||
# migrates any legacy ``enabled_mcp_tools`` field out of profile.yaml
|
||||
# on first read, so existing installs upgrade silently.
|
||||
from framework.agents.queen.queen_tools_config import load_queen_tools_config
|
||||
|
||||
# Build a minimal catalog for default-tool resolution. The full
|
||||
# ``session_manager._mcp_tool_catalog`` snapshot is written further
|
||||
# down the flow; a queen booted for the first time needs the catalog
|
||||
# now so ``@server:NAME`` shorthands in the role-default table can
|
||||
# expand against the just-loaded MCP servers.
|
||||
_boot_catalog: dict[str, list[dict]] = {
|
||||
srv: [{"name": name} for name in sorted(names)] for srv, names in mcp_server_tools_map.items()
|
||||
}
|
||||
# ``queen_dir`` is ``queens/<queen_id>/sessions/<session_id>``; the
|
||||
# allowlist sidecar is keyed by queen_id, not session_id.
|
||||
phase_state.enabled_mcp_tools = load_queen_tools_config(session.queen_name, _boot_catalog)
|
||||
phase_state.rebuild_independent_filter()
|
||||
if phase_state.enabled_mcp_tools is not None:
|
||||
total_mcp = len(phase_state.mcp_tool_names_all)
|
||||
allowed_mcp = len(set(phase_state.enabled_mcp_tools) & phase_state.mcp_tool_names_all)
|
||||
logger.info(
|
||||
"Queen: per-queen MCP allowlist active — %d of %d MCP tools enabled",
|
||||
allowed_mcp,
|
||||
total_mcp,
|
||||
)
|
||||
|
||||
# ---- MCP tool catalog for the frontend ---------------------------
|
||||
# Snapshot per-server tool metadata so the Queen Tools API can render
|
||||
# the tool surface without spawning MCP subprocesses. Keyed by server
|
||||
# name so the UI can group tools by origin. Updated every time a
|
||||
# queen boots, so installing a new server and starting a new queen
|
||||
# session refreshes the catalog.
|
||||
mcp_tool_catalog: dict[str, list[dict[str, Any]]] = {}
|
||||
tools_by_name = {t.name: t for t in queen_tools}
|
||||
for server_name, tool_names in mcp_server_tools_map.items():
|
||||
server_entries: list[dict[str, Any]] = []
|
||||
for tool_name in sorted(tool_names):
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
server_entries.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.parameters,
|
||||
}
|
||||
)
|
||||
mcp_tool_catalog[server_name] = server_entries
|
||||
# All queens share one MCP registry, so the catalog is a manager-level
|
||||
# fact; stash it on the SessionManager so the Queen Tools route can
|
||||
# render the tool list even when no queen session is currently live.
|
||||
if session_manager is not None:
|
||||
try:
|
||||
session_manager._mcp_tool_catalog = mcp_tool_catalog # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("Queen: could not attach mcp_tool_catalog to manager", exc_info=True)
|
||||
|
||||
# ---- Global + queen-scoped memory ----------------------------------
|
||||
global_dir, queen_mem_dir = initialize_memory_scopes(session, phase_state)
|
||||
|
||||
@@ -476,12 +666,34 @@ async def create_queen(
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
_queen_skill_dirs: list[str] = []
|
||||
try:
|
||||
from framework.config import QUEENS_DIR
|
||||
from framework.skills.discovery import ExtraScope
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
|
||||
# Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/)
|
||||
# are discovered. Queen has no agent-specific project root, so we use its
|
||||
# own directory — the value just needs to be non-None to enable user-scope scanning.
|
||||
_queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent))
|
||||
# Queen home backs the queen-UI skill scope and the queen's
|
||||
# override store. The directory already exists (or is created on
|
||||
# demand by queen_profiles.py); treat a missing queen_name as the
|
||||
# default queen to preserve backwards compatibility.
|
||||
_queen_id = getattr(session, "queen_name", None) or "default"
|
||||
_queen_home = QUEENS_DIR / _queen_id
|
||||
_queen_skills_mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id=_queen_id,
|
||||
queen_overrides_path=_queen_home / "skills_overrides.json",
|
||||
extra_scope_dirs=[
|
||||
ExtraScope(
|
||||
directory=_queen_home / "skills",
|
||||
label="queen_ui",
|
||||
priority=2,
|
||||
)
|
||||
],
|
||||
# No project_root — queen's project is her own identity;
|
||||
# user-scope discovery still runs without one.
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
|
||||
@@ -520,8 +732,37 @@ async def create_queen(
|
||||
|
||||
# ---- Recall on each real user turn --------------------------------
|
||||
async def _recall_on_user_input(event: AgentEvent) -> None:
|
||||
"""Re-select memories when real user input arrives."""
|
||||
await _refresh_recall_cache((event.data or {}).get("content", ""))
|
||||
"""On real user input, freeze the dynamic system-prompt suffix and
|
||||
refresh recall memories in the background.
|
||||
|
||||
The EventBus drops handlers that exceed 15s, so we MUST return fast.
|
||||
Recall selection queries the LLM and can take >15s on slow backends;
|
||||
we fire it off as a background task and re-stamp the suffix when it
|
||||
completes. The immediate refresh_dynamic_suffix call stamps a fresh
|
||||
timestamp using the last-known recall blocks so every iteration of
|
||||
THIS user turn sees a byte-stable prompt (prompt cache hits on the
|
||||
static block). Phase-change injections and worker-report injections
|
||||
go through agent_loop.inject_event() and do NOT publish
|
||||
CLIENT_INPUT_RECEIVED, so this runs exactly once per real user turn.
|
||||
"""
|
||||
query = (event.data or {}).get("content", "")
|
||||
# Immediate: stamp "now" into the frozen suffix, using whatever
|
||||
# recall blocks we already cached (from the prior turn or seeding).
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
|
||||
async def _bg_refresh() -> None:
|
||||
try:
|
||||
await _refresh_recall_cache(query)
|
||||
# Re-stamp with the fresh recall blocks. Any iteration that
|
||||
# read the suffix before this point used the older recall
|
||||
# — acceptable; recall was already eventual-consistency.
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
except Exception:
|
||||
logger.debug("background recall refresh failed", exc_info=True)
|
||||
|
||||
import asyncio as _asyncio
|
||||
|
||||
_asyncio.create_task(_bg_refresh())
|
||||
|
||||
session.event_bus.subscribe(
|
||||
[EventType.CLIENT_INPUT_RECEIVED],
|
||||
@@ -631,6 +872,9 @@ async def create_queen(
|
||||
except Exception:
|
||||
logger.debug("recall: initial seeding failed", exc_info=True)
|
||||
|
||||
# Freeze the dynamic suffix once so the first real turn sends a
|
||||
# byte-stable prompt even before CLIENT_INPUT_RECEIVED fires.
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
return HookResult(system_prompt=phase_state.get_current_prompt())
|
||||
|
||||
# ---- Colony preparation -------------------------------------------
|
||||
@@ -730,7 +974,8 @@ async def create_queen(
|
||||
stream_id="queen",
|
||||
execution_id=session.id,
|
||||
dynamic_tools_provider=phase_state.get_current_tools,
|
||||
dynamic_prompt_provider=phase_state.get_current_prompt,
|
||||
dynamic_prompt_provider=phase_state.get_static_prompt,
|
||||
dynamic_prompt_suffix_provider=phase_state.get_dynamic_suffix,
|
||||
iteration_metadata_provider=lambda: {"phase": phase_state.phase},
|
||||
skills_catalog_prompt=phase_state.skills_catalog_prompt,
|
||||
protocols_prompt=phase_state.protocols_prompt,
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
"""Per-colony MCP tool allowlist routes.
|
||||
|
||||
- GET /api/colony/{colony_name}/tools -- enumerate colony tool surface
|
||||
- PATCH /api/colony/{colony_name}/tools -- set or clear the allowlist
|
||||
|
||||
A colony's tool set is inherited from the queen that forked it, so the
|
||||
tool surface mirrors the queen's MCP servers. Lifecycle/synthetic tools
|
||||
are included for display only. MCP tools are grouped by origin server
|
||||
with per-tool ``enabled`` flags.
|
||||
|
||||
Semantics:
|
||||
|
||||
- ``enabled_mcp_tools: null`` → allow every MCP tool (default).
|
||||
- ``enabled_mcp_tools: []`` → allow no MCP tools (only lifecycle /
|
||||
synthetic pass through).
|
||||
- ``enabled_mcp_tools: [...]`` → only listed names pass.
|
||||
|
||||
The allowlist is persisted in a dedicated ``tools.json`` sidecar at
|
||||
``~/.hive/colonies/{colony_name}/tools.json``. Changes take effect on
|
||||
the *next* worker spawn. In-flight workers keep the tool list they
|
||||
booted with because workers have no dynamic tools provider today —
|
||||
mutating their tool set mid-turn would diverge from the list the LLM
|
||||
is already using.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.host.colony_metadata import colony_metadata_path
|
||||
from framework.host.colony_tools_config import (
|
||||
load_colony_tools_config,
|
||||
update_colony_tools_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SYNTHETIC_NAMES = {"ask_user"}
|
||||
|
||||
|
||||
def _synthetic_entries() -> list[dict[str, Any]]:
|
||||
try:
|
||||
from framework.agent_loop.internals.synthetic_tools import build_ask_user_tool
|
||||
|
||||
tool = build_ask_user_tool()
|
||||
return [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
except Exception:
|
||||
return [
|
||||
{
|
||||
"name": "ask_user",
|
||||
"description": "Pause and ask the user a structured question.",
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _colony_runtimes_for_name(manager: Any, colony_name: str) -> list[Any]:
|
||||
"""Return every live ColonyRuntime whose session is attached to ``colony_name``."""
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
runtimes: list[Any] = []
|
||||
for session in sessions.values():
|
||||
if getattr(session, "colony_name", None) != colony_name:
|
||||
continue
|
||||
# Both ``session.colony`` (queen-side unified runtime) and
|
||||
# ``session.colony_runtime`` (legacy worker runtime) may carry
|
||||
# tools that need the allowlist applied. We update both.
|
||||
for attr in ("colony", "colony_runtime"):
|
||||
rt = getattr(session, attr, None)
|
||||
if rt is not None and rt not in runtimes:
|
||||
runtimes.append(rt)
|
||||
return runtimes
|
||||
|
||||
|
||||
async def _render_catalog(manager: Any, colony_name: str) -> dict[str, list[dict[str, Any]]]:
    """Build a per-server tool catalog for this colony.

    All colonies inherit the queen's MCP surface, so we reuse the
    manager-level ``_mcp_tool_catalog`` populated during queen boot.

    Resolution order:

    1. A live colony runtime's own tool registry (authoritative — it
       reflects any MCP servers hot-added after queen boot). All MCP
       tools are grouped under a single ``"(mcp)"`` pseudo-server
       because the runtime does not preserve per-server origin.
    2. The manager-level cached snapshot, when non-empty.
    3. A fresh bootstrap built off the event loop, cached back onto the
       manager for subsequent calls.

    Returns an empty dict when bootstrap fails (callers treat that as a
    "stale" catalog).
    """
    # If a live runtime exists and carries its own registry, prefer it —
    # it's authoritative (reflects any post-queen-boot MCP additions).
    for rt in _colony_runtimes_for_name(manager, colony_name):
        tools = getattr(rt, "_tools", None)
        if not tools:
            continue
        # Without the MCP-origin name set we can't tell MCP tools apart
        # from lifecycle tools, so skip this runtime entirely.
        mcp_names = set(getattr(rt, "_mcp_tool_names_all", set()) or set())
        if not mcp_names:
            continue
        catalog: dict[str, list[dict[str, Any]]] = {"(mcp)": []}
        for tool in tools:
            name = getattr(tool, "name", None)
            if name in mcp_names:
                catalog["(mcp)"].append(
                    {
                        "name": name,
                        "description": getattr(tool, "description", ""),
                        "input_schema": getattr(tool, "parameters", {}),
                    }
                )
        # First usable runtime wins; remaining runtimes share the same
        # inherited MCP surface.
        return catalog

    # Otherwise fall back to the queen-level snapshot. Build it on demand
    # (off the event loop) when empty so the Tool Library works before
    # any queen has been started in this process.
    cached = getattr(manager, "_mcp_tool_catalog", None)
    if isinstance(cached, dict) and cached:
        return cached
    try:
        import asyncio

        from framework.server.queen_orchestrator import build_queen_tool_registry_bare

        # build_queen_tool_registry_bare spawns MCP subprocesses — run it
        # in a worker thread so the HTTP worker stays responsive.
        registry, built = await asyncio.to_thread(build_queen_tool_registry_bare)
        if manager is not None:
            # Cache both artifacts so later calls (and the queen-tools
            # routes) reuse this one expensive bootstrap.
            manager._mcp_tool_catalog = built  # type: ignore[attr-defined]
            manager._bootstrap_tool_registry = registry  # type: ignore[attr-defined]
        return built
    except Exception:
        # Best-effort: an empty catalog marks the response as stale
        # rather than failing the whole request.
        logger.warning("Colony tools: catalog bootstrap failed", exc_info=True)
        return {}
|
||||
|
||||
|
||||
def _lifecycle_entries_from_runtime(manager: Any, colony_name: str) -> list[dict[str, Any]]:
    """Non-MCP tools currently registered on the colony runtime (if any).

    When no live runtime is available we fall back to the bootstrap
    registry stashed on the manager by ``routes_queen_tools`` — it
    already has queen lifecycle tools registered, which are also the
    lifecycle tools colonies inherit at spawn time.

    Returns entries shaped ``{"name", "description", "editable": False}``,
    sorted by name regardless of which source supplied them. (Previously
    the bootstrap-registry path returned early and skipped the sort, so
    the response ordering depended on whether a runtime was live.)
    """
    out: list[dict[str, Any]] = []
    seen: set[str] = set()

    def _push(name: str, description: str) -> None:
        # De-duplicate across runtimes; synthetic tools (e.g. ask_user)
        # are reported separately, not as lifecycle entries.
        if not name or name in seen:
            return
        if name in _SYNTHETIC_NAMES:
            return
        seen.add(name)
        out.append({"name": name, "description": description, "editable": False})

    runtimes = _colony_runtimes_for_name(manager, colony_name)
    if runtimes:
        for rt in runtimes:
            # Everything registered on the runtime that is NOT MCP-origin
            # counts as a lifecycle tool.
            mcp_names = set(getattr(rt, "_mcp_tool_names_all", set()) or set())
            for tool in getattr(rt, "_tools", []) or []:
                name = getattr(tool, "name", None)
                if name in mcp_names:
                    continue
                _push(name, getattr(tool, "description", ""))
    else:
        # No live runtime — derive from the bootstrap registry.
        from framework.server.routes_queen_tools import _lifecycle_entries_without_session

        catalog = getattr(manager, "_mcp_tool_catalog", {}) or {}
        bootstrap_mcp_names: set[str] = {
            entry["name"]
            for entries in catalog.values()
            for entry in entries
            if entry.get("name")
        }
        out.extend(_lifecycle_entries_without_session(manager, bootstrap_mcp_names))

    # Single exit point: both branches now return name-sorted output.
    return sorted(out, key=lambda e: e["name"])
|
||||
|
||||
|
||||
def _render_servers(
|
||||
catalog: dict[str, list[dict[str, Any]]],
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
allowed: set[str] | None = None if enabled_mcp_tools is None else set(enabled_mcp_tools)
|
||||
servers: list[dict[str, Any]] = []
|
||||
for name in sorted(catalog):
|
||||
tools = []
|
||||
for entry in catalog[name]:
|
||||
tool_name = entry.get("name")
|
||||
tools.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"description": entry.get("description", ""),
|
||||
"input_schema": entry.get("input_schema", {}),
|
||||
"enabled": True if allowed is None else tool_name in allowed,
|
||||
}
|
||||
)
|
||||
servers.append({"name": name, "tools": tools})
|
||||
return servers
|
||||
|
||||
|
||||
async def handle_get_tools(request: web.Request) -> web.Response:
    """GET /api/colony/{colony_name}/tools.

    Enumerates the colony's tool surface for the Tool Library UI:
    lifecycle tools (non-editable), synthetic tools, and MCP tools
    grouped by server with per-tool ``enabled`` flags. 404s when the
    colony's metadata file does not exist. ``stale: true`` means no
    catalog could be produced (caller should treat flags as unknown).
    """
    colony_name = request.match_info["colony_name"]
    if not colony_metadata_path(colony_name).exists():
        return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)

    manager = request.app.get("manager")
    # Allowlist now lives in a dedicated tools.json sidecar; helper
    # migrates any legacy metadata.json field on first read.
    # None => allow every MCP tool (default), [] => all disabled.
    enabled = load_colony_tools_config(colony_name)

    catalog = await _render_catalog(manager, colony_name)
    stale = not catalog

    return web.json_response(
        {
            "colony_name": colony_name,
            "enabled_mcp_tools": enabled,
            "stale": stale,
            "lifecycle": _lifecycle_entries_from_runtime(manager, colony_name),
            "synthetic": _synthetic_entries(),
            "mcp_servers": _render_servers(catalog, enabled),
        }
    )
|
||||
|
||||
|
||||
async def handle_patch_tools(request: web.Request) -> web.Response:
    """PATCH /api/colony/{colony_name}/tools.

    Sets (list of tool names), or clears (``null``), the colony's MCP
    tool allowlist. Validates names against the known catalog, persists
    to the tools.json sidecar, then pushes the new allowlist onto any
    live runtimes so the NEXT worker spawn picks it up. Returns 400 on
    malformed bodies or unknown tool names, 404 when the colony is
    missing.
    """
    colony_name = request.match_info["colony_name"]
    if not colony_metadata_path(colony_name).exists():
        return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # The field must be present even when null — absence is a client bug,
    # not a "clear the allowlist" request.
    if not isinstance(body, dict) or "enabled_mcp_tools" not in body:
        return web.json_response(
            {"error": "Body must be an object with an 'enabled_mcp_tools' field"},
            status=400,
        )

    enabled = body["enabled_mcp_tools"]
    if enabled is not None:
        if not isinstance(enabled, list) or not all(isinstance(x, str) for x in enabled):
            return web.json_response(
                {"error": "'enabled_mcp_tools' must be null or a list of strings"},
                status=400,
            )

    manager = request.app.get("manager")

    # Validate names against the known MCP catalog — lifts the same
    # typo-catching guarantee we already offer on queen tools.
    # When the catalog is empty (stale) we deliberately skip validation
    # rather than reject every save.
    catalog = await _render_catalog(manager, colony_name)
    known: set[str] = {e.get("name") for entries in catalog.values() for e in entries if e.get("name")}
    if enabled is not None and known:
        unknown = sorted(set(enabled) - known)
        if unknown:
            return web.json_response(
                {"error": "Unknown MCP tool name(s)", "unknown": unknown},
                status=400,
            )

    # Persist — tools.json sidecar, not metadata.json. Missing directory
    # is already guarded by the 404 check above.
    try:
        update_colony_tools_config(colony_name, enabled)
    except FileNotFoundError:
        # Colony dir vanished between the existence check and the write.
        return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)

    # Update any live runtimes so the NEXT worker spawn reflects the change.
    # We do NOT rebuild in-flight workers' tool lists (see module docstring).
    refreshed = 0
    for rt in _colony_runtimes_for_name(manager, colony_name):
        setter = getattr(rt, "set_tool_allowlist", None)
        if callable(setter):
            try:
                setter(enabled)
                refreshed += 1
            except Exception:
                # Best-effort push: persistence already succeeded, so a
                # runtime refresh failure is only worth a debug log.
                logger.debug(
                    "Colony tools: set_tool_allowlist failed on runtime for %s",
                    colony_name,
                    exc_info=True,
                )

    logger.info(
        "Colony tools: colony=%s allowlist=%s refreshed_runtimes=%d",
        colony_name,
        "null" if enabled is None else f"{len(enabled)} tool(s)",
        refreshed,
    )
    return web.json_response(
        {
            "colony_name": colony_name,
            "enabled_mcp_tools": enabled,
            "refreshed_runtimes": refreshed,
            "note": "Changes apply to the next worker spawn. Running workers keep their booted tool list.",
        }
    )
|
||||
|
||||
|
||||
async def handle_list_colonies(request: web.Request) -> web.Response:
    """GET /api/colonies — list colonies with their tool allowlist status.

    Powers the Tool Library page's colony picker. For each colony this
    reports provenance (queen name, creation time) plus whether a
    tools.json allowlist exists and how many tools it enables
    (``enabled_count`` is null for allow-all colonies).
    """
    from framework.host.colony_metadata import list_colony_names, load_colony_metadata

    colonies: list[dict[str, Any]] = []
    for name in list_colony_names():
        meta = load_colony_metadata(name)
        # Provenance stays in metadata.json; allowlist lives in tools.json.
        allowlist = load_colony_tools_config(name)
        colonies.append(
            {
                "name": name,
                "queen_name": meta.get("queen_name"),
                "created_at": meta.get("created_at"),
                "has_allowlist": allowlist is not None,
                # None (allow-all) deliberately maps to null, not 0.
                "enabled_count": len(allowlist) if isinstance(allowlist, list) else None,
            }
        )
    return web.json_response({"colonies": colonies})
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
    """Wire the colony Tool Library endpoints onto *app*."""
    router = app.router
    router.add_get("/api/colonies/tools-index", handle_list_colonies)
    router.add_get("/api/colony/{colony_name}/tools", handle_get_tools)
    router.add_patch("/api/colony/{colony_name}/tools", handle_patch_tools)
|
||||
@@ -1577,6 +1577,39 @@ async def fork_session_into_colony(
|
||||
}
|
||||
metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
|
||||
# ── 4a. Inherit the queen's tool allowlist into the colony ───
|
||||
# A colony forked from a curated queen should start with the same
|
||||
# tool surface (otherwise the colony silently falls back to its own
|
||||
# "allow every MCP tool" default, undoing the parent's curation).
|
||||
# We copy the queen's LIVE effective allowlist so the snapshot
|
||||
# reflects whatever was in force the moment the user clicked "Create
|
||||
# Colony". Users can further narrow the colony via the Tool Library.
|
||||
# Skip the write when the queen is on allow-all (None) so the colony
|
||||
# keeps the same semantics without creating an inert sidecar.
|
||||
try:
|
||||
queen_enabled = getattr(
|
||||
getattr(session, "phase_state", None),
|
||||
"enabled_mcp_tools",
|
||||
None,
|
||||
)
|
||||
if isinstance(queen_enabled, list):
|
||||
from framework.host.colony_tools_config import update_colony_tools_config
|
||||
|
||||
update_colony_tools_config(colony_name, list(queen_enabled))
|
||||
logger.info(
|
||||
"Inherited queen allowlist into colony '%s' (%d tools)",
|
||||
colony_name,
|
||||
len(queen_enabled),
|
||||
)
|
||||
except Exception:
|
||||
# Inheritance is best-effort — don't let a tools.json hiccup
|
||||
# abort colony creation.
|
||||
logger.warning(
|
||||
"Failed to inherit queen allowlist into colony '%s'",
|
||||
colony_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ── 5. Update source queen session meta.json ─────────────────
|
||||
# Link the originating session back to the colony for discovery.
|
||||
source_meta_path = source_queen_dir / "meta.json"
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
"""MCP server registration routes.
|
||||
|
||||
Thin HTTP wrapper around ``MCPRegistry`` so the frontend can add, remove,
|
||||
enable, and health-check user-registered MCP servers. The CLI path
|
||||
(``hive mcp add`` / ``hive mcp remove`` / etc.) is unchanged.
|
||||
|
||||
- GET /api/mcp/servers -- list installed servers
|
||||
- POST /api/mcp/servers -- register a local server
|
||||
- DELETE /api/mcp/servers/{name} -- remove a local server
|
||||
- POST /api/mcp/servers/{name}/enable -- enable a server
|
||||
- POST /api/mcp/servers/{name}/disable -- disable a server
|
||||
- POST /api/mcp/servers/{name}/health -- probe server health
|
||||
|
||||
New servers take effect on the *next* queen session start. Existing live
|
||||
queen sessions keep the tool list they booted with to avoid mid-turn
|
||||
cache invalidation. The ``add`` response hints at this explicitly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.loader.mcp_errors import MCPError
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_VALID_TRANSPORTS = {"stdio", "http", "sse", "unix"}
|
||||
|
||||
|
||||
def _registry() -> MCPRegistry:
    """Return a freshly initialized registry handle.

    ``MCPRegistry`` just fronts ``~/.hive/mcp_registry/installed.json``,
    so building one per request is cheap and sidesteps any app-level
    caching/invalidation concerns.
    """
    registry = MCPRegistry()
    registry.initialize()
    return registry
|
||||
|
||||
|
||||
def _package_builtin_servers() -> list[dict[str, Any]]:
    """Return the package-baked queen MCP servers from ``queen/mcp_servers.json``.

    Those servers are loaded directly by ``ToolRegistry.load_mcp_config``
    at queen boot and never go through ``MCPRegistry.list_installed``,
    so the raw registry view shows them as missing. Surface them here so
    the Tool Library reflects what the queen actually talks to.

    Entries carry ``source: "built-in"`` and are NOT removable / toggleable
    — editing them requires changing the repo file.

    Returns an empty list when the file is missing, unreadable, or not a
    JSON object.
    """
    import json
    from pathlib import Path

    import framework.agents.queen as _queen_pkg

    path = Path(_queen_pkg.__file__).parent / "mcp_servers.json"
    if not path.exists():
        return []
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return []
    # A malformed file whose top level is not an object (e.g. a bare list)
    # would raise AttributeError on .items(); treat it like any other bad
    # file and degrade to "no built-ins".
    if not isinstance(data, dict):
        return []

    out: list[dict[str, Any]] = []
    for name, cfg in data.items():
        if not isinstance(cfg, dict):
            continue
        out.append(
            {
                "name": name,
                "source": "built-in",
                "transport": cfg.get("transport", "stdio"),
                "description": cfg.get("description", "") or "",
                "enabled": True,
                # Health fields are unknown for built-ins — they are never
                # probed through MCPRegistry.
                "last_health_status": None,
                "last_error": None,
                "last_health_check_at": None,
                "tool_count": None,
                "removable": False,
            }
        )
    return out
|
||||
|
||||
|
||||
def _server_to_summary(entry: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Shape an installed.json entry for API responses.
|
||||
|
||||
Strips the full manifest body (which can be large) but keeps the tool
|
||||
list if the manifest already embeds one (happens for registry-installed
|
||||
servers). Users with ``source: "local"`` only get a tool list after
|
||||
running a health check.
|
||||
"""
|
||||
manifest = entry.get("manifest") or {}
|
||||
tools = manifest.get("tools") if isinstance(manifest, dict) else None
|
||||
if not isinstance(tools, list):
|
||||
tools = None
|
||||
return {
|
||||
"name": entry.get("name"),
|
||||
"source": entry.get("source"),
|
||||
"transport": entry.get("transport"),
|
||||
"description": (manifest.get("description") if isinstance(manifest, dict) else None) or "",
|
||||
"enabled": entry.get("enabled", True),
|
||||
"last_health_status": entry.get("last_health_status"),
|
||||
"last_error": entry.get("last_error"),
|
||||
"last_health_check_at": entry.get("last_health_check_at"),
|
||||
"tool_count": (len(tools) if tools is not None else None),
|
||||
}
|
||||
|
||||
|
||||
def _mcp_error_response(exc: MCPError, *, default_status: int = 400) -> web.Response:
    """Render an ``MCPError`` as the structured what/why/fix JSON payload.

    ``error`` duplicates ``what`` for clients that only look at the
    conventional error field.
    """
    payload = {
        "error": exc.what,
        "code": exc.code.value,
        "what": exc.what,
        "why": exc.why,
        "fix": exc.fix,
    }
    return web.json_response(payload, status=default_status)
|
||||
|
||||
|
||||
async def handle_list_servers(request: web.Request) -> web.Response:
    """GET /api/mcp/servers — list every server the queen actually uses.

    Merges two sources: user/registry servers from
    ``MCPRegistry.list_installed()`` (``source: "local"`` or
    ``source: "registry"``, persisted under
    ``~/.hive/mcp_registry/installed.json``) and the repo-baked queen
    servers from ``core/framework/agents/queen/mcp_servers.json``, which
    the queen's ``ToolRegistry`` loads directly at boot and which never
    touch ``MCPRegistry``. Built-ins are surfaced so the UI reflects the
    queen's real surface, but are not removable from the UI.

    On a name collision the registry entry wins — it's the one the user
    has customized.
    """
    installed = [_server_to_summary(record) for record in _registry().list_installed()]
    taken = {summary.get("name") for summary in installed}
    builtins = [b for b in _package_builtin_servers() if b.get("name") not in taken]
    return web.json_response({"servers": [*builtins, *installed]})
|
||||
|
||||
|
||||
async def handle_add_server(request: web.Request) -> web.Response:
    """POST /api/mcp/servers — register a local MCP server.

    Body mirrors ``MCPRegistry.add_local`` args:

    ::

        {
            "name": "my-tool",
            "transport": "stdio" | "http" | "sse" | "unix",
            "command": "...", "args": [...], "env": {...}, "cwd": "...",
            "url": "...", "headers": {...},
            "socket_path": "...",
            "description": "..."
        }

    Returns 201 with the new server's summary, 400 on validation errors,
    409 when the name already exists. The server takes effect on the
    next queen session start (see module docstring).
    """
    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    if not isinstance(body, dict):
        return web.json_response({"error": "Body must be a JSON object"}, status=400)

    name = body.get("name")
    transport = body.get("transport")
    if not isinstance(name, str) or not name.strip():
        return web.json_response({"error": "'name' is required"}, status=400)
    if transport not in _VALID_TRANSPORTS:
        return web.json_response(
            {"error": f"'transport' must be one of {sorted(_VALID_TRANSPORTS)}"},
            status=400,
        )

    # Strip once and use the cleaned name everywhere — the registry stores
    # the stripped name, so echoing the raw one back would hand the client
    # a name that doesn't match later GET/DELETE calls.
    clean_name = name.strip()

    reg = _registry()
    try:
        entry = reg.add_local(
            name=clean_name,
            transport=transport,
            command=body.get("command"),
            args=body.get("args"),
            env=body.get("env"),
            cwd=body.get("cwd"),
            url=body.get("url"),
            headers=body.get("headers"),
            socket_path=body.get("socket_path"),
            description=body.get("description", ""),
        )
    except MCPError as exc:
        # Duplicate registration maps to 409; all other MCP errors are
        # client errors.
        status = 409 if "already exists" in exc.what else 400
        return _mcp_error_response(exc, default_status=status)
    except Exception as exc:
        logger.exception("MCP add_local failed for %r", clean_name)
        return web.json_response({"error": str(exc)}, status=500)

    summary = _server_to_summary({"name": clean_name, **entry})
    return web.json_response(
        {
            "server": summary,
            "hint": "Start a new queen session to use this server's tools.",
        },
        status=201,
    )
|
||||
|
||||
|
||||
async def handle_remove_server(request: web.Request) -> web.Response:
    """DELETE /api/mcp/servers/{name} — remove a user-registered server.

    404 when the server is not installed; 400 when it exists but was not
    added by the user (only ``source: "local"`` entries are deletable
    over HTTP).
    """
    name = request.match_info["name"]
    reg = _registry()

    entry = reg.get_server(name)
    if entry is None:
        return web.json_response({"error": f"Server '{name}' not installed"}, status=404)
    if entry.get("source") != "local":
        return web.json_response(
            {
                "error": f"Server '{name}' is a built-in; it cannot be removed from the UI.",
            },
            status=400,
        )

    try:
        reg.remove(name)
    except MCPError as exc:
        return _mcp_error_response(exc, default_status=404)
    return web.json_response({"removed": name})
|
||||
|
||||
|
||||
async def handle_set_enabled(request: web.Request, *, enabled: bool) -> web.Response:
    """Shared body for the enable/disable routes — flip one server's flag."""
    name = request.match_info["name"]
    reg = _registry()
    toggle = reg.enable if enabled else reg.disable
    try:
        toggle(name)
    except MCPError as exc:
        # Unknown server name surfaces as 404.
        return _mcp_error_response(exc, default_status=404)
    return web.json_response({"name": name, "enabled": enabled})
|
||||
|
||||
|
||||
async def handle_enable(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/enable — mark the server enabled (404 if unknown)."""
    return await handle_set_enabled(request, enabled=True)
|
||||
|
||||
|
||||
async def handle_disable(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/disable — mark the server disabled (404 if unknown)."""
    return await handle_set_enabled(request, enabled=False)
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/health — probe one server.

    Returns the registry's health-check result as JSON; 404 for unknown
    servers, 500 for unexpected probe failures.
    """
    import asyncio

    name = request.match_info["name"]
    reg = _registry()
    try:
        # Health probing blocks on subprocess IO, so run it in a worker
        # thread and keep the event loop responsive.
        result = await asyncio.to_thread(reg.health_check, name)
    except MCPError as exc:
        return _mcp_error_response(exc, default_status=404)
    except Exception as exc:
        logger.exception("MCP health_check failed for %r", name)
        return web.json_response({"error": str(exc)}, status=500)
    return web.json_response(result)
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
    """Attach the MCP server management endpoints to *app*."""
    base = "/api/mcp/servers"
    app.router.add_get(base, handle_list_servers)
    app.router.add_post(base, handle_add_server)
    app.router.add_delete(base + "/{name}", handle_remove_server)
    app.router.add_post(base + "/{name}/enable", handle_enable)
    app.router.add_post(base + "/{name}/disable", handle_disable)
    app.router.add_post(base + "/{name}/health", handle_health)
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Per-queen MCP tool allowlist routes.
|
||||
|
||||
- GET /api/queen/{queen_id}/tools -- enumerate the queen's tool surface
|
||||
- PATCH /api/queen/{queen_id}/tools -- set or clear the MCP tool allowlist
|
||||
|
||||
Lifecycle and synthetic tools (``ask_user``) are always part of the queen's
|
||||
surface in INDEPENDENT mode and are returned with ``editable: false``. MCP
|
||||
tools are grouped by origin server and carry per-tool ``enabled`` flags.
|
||||
|
||||
The allowlist is persisted in a dedicated ``tools.json`` sidecar at
|
||||
``~/.hive/agents/queens/{queen_id}/tools.json``:
|
||||
|
||||
- ``null`` / missing file -> "allow every MCP tool" (default)
|
||||
- ``[]`` -> explicitly disable every MCP tool
|
||||
- ``["foo", "bar"]`` -> only these MCP tools pass through to the LLM
|
||||
|
||||
Filtering happens in ``QueenPhaseState.rebuild_independent_filter`` so the
|
||||
LLM prompt cache stays warm between saves.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.agents.queen.queen_profiles import (
|
||||
ensure_default_queens,
|
||||
load_queen_profile,
|
||||
)
|
||||
from framework.agents.queen.queen_tools_config import (
|
||||
delete_queen_tools_config,
|
||||
load_queen_tools_config,
|
||||
tools_config_exists,
|
||||
update_queen_tools_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SYNTHETIC_NAMES = {"ask_user"}
|
||||
|
||||
|
||||
async def _ensure_manager_catalog(manager: Any) -> dict[str, list[dict[str, Any]]]:
    """Return the cached MCP tool catalog, building it on first call.

    ``queen_orchestrator.create_queen`` populates ``_mcp_tool_catalog`` on
    every queen boot. On a fresh backend process the user may open the
    Tool Library before any queen session has started, so the catalog is
    empty. In that case we build one from the shared MCP config; the
    first call pays an MCP-subprocess-spawn cost, subsequent calls are
    cache hits. The build runs off the event loop via asyncio.to_thread
    so the HTTP worker stays responsive while MCP servers initialize.

    Returns an empty dict when no manager is available or the bootstrap
    fails (callers treat that as a stale catalog).
    """
    if manager is None:
        return {}
    catalog = getattr(manager, "_mcp_tool_catalog", None)
    if isinstance(catalog, dict) and catalog:
        return catalog
    try:
        import asyncio

        from framework.server.queen_orchestrator import build_queen_tool_registry_bare

        registry, built = await asyncio.to_thread(build_queen_tool_registry_bare)
        # Stash both artifacts: the catalog for this route, the registry
        # for _lifecycle_entries_without_session's bootstrap path.
        manager._mcp_tool_catalog = built  # type: ignore[attr-defined]
        manager._bootstrap_tool_registry = registry  # type: ignore[attr-defined]
        return built
    except Exception:
        logger.warning("Tool catalog bootstrap failed", exc_info=True)
        return {}
|
||||
|
||||
|
||||
def _lifecycle_entries_without_session(
    manager: Any,
    mcp_names: set[str],
) -> list[dict[str, Any]]:
    """Derive lifecycle tool names from the registry even without a session.

    We register queen lifecycle tools against a temporary registry using a
    minimal stub, then subtract the MCP-origin set and the synthetic set.
    The result matches what the queen sees at runtime (minus context-
    specific variants).

    Entries come back sorted by registry key and carry
    ``editable: False`` (lifecycle tools cannot be toggled).
    """
    registry = getattr(manager, "_bootstrap_tool_registry", None)
    # If the bootstrap registry exists but doesn't carry lifecycle tools
    # yet, register them now. The `_lifecycle_bootstrap_done` flag makes
    # this a one-time, idempotent step per registry instance.
    if registry is not None and not getattr(registry, "_lifecycle_bootstrap_done", False):
        try:
            from types import SimpleNamespace

            from framework.tools.queen_lifecycle_tools import register_queen_lifecycle_tools

            # Minimal stand-in for a live session: lifecycle registration
            # only needs attribute access, not working behavior.
            stub_session = SimpleNamespace(
                id="tool-library-bootstrap",
                colony_runtime=None,
                event_bus=None,
                worker_path=None,
                phase_state=None,
                llm=None,
            )
            register_queen_lifecycle_tools(
                registry,
                session=stub_session,
                session_id=stub_session.id,
                session_manager=None,
                manager_session_id=stub_session.id,
                phase_state=None,
            )
            registry._lifecycle_bootstrap_done = True  # type: ignore[attr-defined]
        except Exception:
            # Best-effort: a failed bootstrap just yields fewer entries.
            logger.debug("lifecycle bootstrap failed", exc_info=True)

    if registry is None:
        return []

    out: list[dict[str, Any]] = []
    for name, tool in sorted(registry.get_tools().items()):
        # Drop MCP-origin and synthetic tools — only lifecycle remains.
        if name in mcp_names or name in _SYNTHETIC_NAMES:
            continue
        out.append(
            {
                "name": tool.name,
                "description": tool.description,
                "editable": False,
            }
        )
    return out
|
||||
|
||||
|
||||
def _synthetic_entries() -> list[dict[str, Any]]:
|
||||
"""Return display metadata for synthetic tools injected by the agent loop.
|
||||
|
||||
Kept behind a lazy import so test harnesses that don't wire the agent
|
||||
loop can still hit this route without blowing up.
|
||||
"""
|
||||
try:
|
||||
from framework.agent_loop.internals.synthetic_tools import build_ask_user_tool
|
||||
|
||||
tool = build_ask_user_tool()
|
||||
return [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
except Exception:
|
||||
return [
|
||||
{
|
||||
"name": "ask_user",
|
||||
"description": "Pause and ask the user a structured question.",
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _live_queen_session(manager: Any, queen_id: str) -> Any:
|
||||
"""Return any live DM session owned by this queen, or ``None``."""
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
for session in sessions.values():
|
||||
if getattr(session, "queen_name", None) != queen_id:
|
||||
continue
|
||||
# Prefer DM (non-colony) sessions
|
||||
if getattr(session, "colony_runtime", None) is None:
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
def _render_mcp_servers(
|
||||
*,
|
||||
mcp_tool_names_by_server: dict[str, list[dict[str, Any]]],
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Shape the mcp_tool_catalog entries for the API response."""
|
||||
allowed: set[str] | None = None if enabled_mcp_tools is None else set(enabled_mcp_tools)
|
||||
servers: list[dict[str, Any]] = []
|
||||
for server_name in sorted(mcp_tool_names_by_server):
|
||||
entries = mcp_tool_names_by_server[server_name]
|
||||
tools = []
|
||||
for entry in entries:
|
||||
name = entry.get("name")
|
||||
enabled = True if allowed is None else name in allowed
|
||||
tools.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": entry.get("description", ""),
|
||||
"input_schema": entry.get("input_schema", {}),
|
||||
"enabled": enabled,
|
||||
}
|
||||
)
|
||||
servers.append({"name": server_name, "tools": tools})
|
||||
return servers
|
||||
|
||||
|
||||
def _catalog_from_live_session(session: Any) -> dict[str, list[dict[str, Any]]]:
    """Rebuild a per-server tool catalog from a live queen session.

    The session's registry is authoritative — this reflects any hot-added
    MCP servers since the manager-level snapshot was cached.

    Two paths:

    - With a registry: group tools by the registry's per-server map,
      skipping names the registry no longer knows.
    - Without one: reconstruct a flat single-group catalog from the
      phase state's independent tools, keeping only MCP-origin names.

    Returns ``{}`` when neither source is available.
    """
    registry = getattr(session, "_queen_tool_registry", None)
    if registry is None:
        # session._queen_tools_by_name is a stash from create_queen; we
        # only have registry via the tools list, so reconstruct from the
        # phase state instead.
        phase_state = getattr(session, "phase_state", None)
        if phase_state is None:
            return {}
        mcp_names = getattr(phase_state, "mcp_tool_names_all", set()) or set()
        independent_tools = getattr(phase_state, "independent_tools", []) or []
        # Per-server origin is lost here, so everything lands under one
        # "MCP Tools" pseudo-server.
        result: dict[str, list[dict[str, Any]]] = {"MCP Tools": []}
        for tool in independent_tools:
            if tool.name not in mcp_names:
                continue
            result["MCP Tools"].append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.parameters,
                }
            )
        # An empty group means "no catalog", not "a server with no tools".
        return result if result["MCP Tools"] else {}

    server_map = getattr(registry, "_mcp_server_tools", {}) or {}
    tools_by_name = {t.name: t for t in registry.get_tools().values()}
    catalog: dict[str, list[dict[str, Any]]] = {}
    for server_name, tool_names in server_map.items():
        entries: list[dict[str, Any]] = []
        for name in sorted(tool_names):
            tool = tools_by_name.get(name)
            # Server map may lag the registry (e.g. a tool was removed);
            # skip names that no longer resolve.
            if tool is None:
                continue
            entries.append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.parameters,
                }
            )
        catalog[server_name] = entries
    return catalog
|
||||
|
||||
|
||||
def _lifecycle_entries(
    *,
    session: Any,
    mcp_tool_names_all: set[str],
) -> list[dict[str, Any]]:
    """Lifecycle surface from a live session's phase state.

    Lifecycle tools = independent tools that are neither MCP-origin nor
    synthetic. Reading them from the live session guarantees the list
    matches exactly what the queen sees on her next turn. Returns
    name-sorted, non-editable entries; empty without a session or phase
    state.
    """
    phase_state = getattr(session, "phase_state", None) if session is not None else None
    if phase_state is None:
        return []

    excluded = mcp_tool_names_all | _SYNTHETIC_NAMES
    entries = [
        {"name": tool.name, "description": tool.description, "editable": False}
        for tool in (getattr(phase_state, "independent_tools", []) or [])
        if tool.name not in excluded
    ]
    entries.sort(key=lambda item: item["name"])
    return entries
|
||||
|
||||
|
||||
async def handle_get_tools(request: web.Request) -> web.Response:
|
||||
"""GET /api/queen/{queen_id}/tools — enumerate tool surface for the UI."""
|
||||
queen_id = request.match_info["queen_id"]
|
||||
ensure_default_queens()
|
||||
try:
|
||||
load_queen_profile(queen_id)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
manager = request.app.get("manager")
|
||||
session = _live_queen_session(manager, queen_id) if manager is not None else None
|
||||
|
||||
# Prefer a live session's registry for freshness. Otherwise use (or
|
||||
# build on demand) the manager-level catalog so the Tool Library works
|
||||
# even before any queen has been started in this process.
|
||||
if session is not None:
|
||||
catalog = _catalog_from_live_session(session)
|
||||
else:
|
||||
catalog = await _ensure_manager_catalog(manager)
|
||||
stale = not catalog
|
||||
|
||||
mcp_tool_names_all: set[str] = set()
|
||||
for entries in catalog.values():
|
||||
for entry in entries:
|
||||
if entry.get("name"):
|
||||
mcp_tool_names_all.add(entry["name"])
|
||||
|
||||
if session is not None:
|
||||
lifecycle = _lifecycle_entries(
|
||||
session=session,
|
||||
mcp_tool_names_all=mcp_tool_names_all,
|
||||
)
|
||||
else:
|
||||
lifecycle = _lifecycle_entries_without_session(manager, mcp_tool_names_all)
|
||||
|
||||
# Allowlist lives in the dedicated tools.json sidecar; helper
|
||||
# migrates legacy profile.yaml field on first read, and falls back
|
||||
# to the role-based default when no sidecar exists.
|
||||
enabled_mcp_tools = load_queen_tools_config(queen_id, mcp_catalog=catalog)
|
||||
is_role_default = not tools_config_exists(queen_id)
|
||||
|
||||
response = {
|
||||
"queen_id": queen_id,
|
||||
"enabled_mcp_tools": enabled_mcp_tools,
|
||||
"is_role_default": is_role_default,
|
||||
"stale": stale,
|
||||
"lifecycle": lifecycle,
|
||||
"synthetic": _synthetic_entries(),
|
||||
"mcp_servers": _render_mcp_servers(
|
||||
mcp_tool_names_by_server=catalog,
|
||||
enabled_mcp_tools=enabled_mcp_tools,
|
||||
),
|
||||
}
|
||||
return web.json_response(response)
|
||||
|
||||
|
||||
async def handle_patch_tools(request: web.Request) -> web.Response:
|
||||
"""PATCH /api/queen/{queen_id}/tools — persist the MCP tool allowlist.
|
||||
|
||||
Body: ``{"enabled_mcp_tools": null | string[]}``.
|
||||
|
||||
- ``null`` resets to "allow every MCP tool" (default).
|
||||
- A list is validated against the known MCP catalog; unknown names
|
||||
are rejected with 400 so the frontend catches typos.
|
||||
"""
|
||||
queen_id = request.match_info["queen_id"]
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict) or "enabled_mcp_tools" not in body:
|
||||
return web.json_response(
|
||||
{"error": "Body must be an object with an 'enabled_mcp_tools' field"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
enabled = body["enabled_mcp_tools"]
|
||||
if enabled is not None:
|
||||
if not isinstance(enabled, list) or not all(isinstance(x, str) for x in enabled):
|
||||
return web.json_response(
|
||||
{"error": "'enabled_mcp_tools' must be null or a list of strings"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
ensure_default_queens()
|
||||
try:
|
||||
load_queen_profile(queen_id)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
# Validate names against the known MCP tool catalog. We prefer a live
|
||||
# session's registry for the most up-to-date set, then fall back to
|
||||
# the manager-level snapshot (building it on demand if absent).
|
||||
manager = request.app.get("manager")
|
||||
session = _live_queen_session(manager, queen_id) if manager is not None else None
|
||||
if session is not None:
|
||||
catalog = _catalog_from_live_session(session)
|
||||
else:
|
||||
catalog = await _ensure_manager_catalog(manager)
|
||||
known_names: set[str] = set()
|
||||
for entries in catalog.values():
|
||||
for entry in entries:
|
||||
if entry.get("name"):
|
||||
known_names.add(entry["name"])
|
||||
|
||||
if enabled is not None and known_names:
|
||||
unknown = sorted(set(enabled) - known_names)
|
||||
if unknown:
|
||||
return web.json_response(
|
||||
{"error": "Unknown MCP tool name(s)", "unknown": unknown},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Persist — tools.json sidecar, not profile.yaml.
|
||||
try:
|
||||
update_queen_tools_config(queen_id, enabled)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
# Hot-reload every live DM session for this queen. The filter memo is
|
||||
# rebuilt so the very next turn sees the new allowlist without a
|
||||
# session restart, and the prompt cache is invalidated exactly once.
|
||||
refreshed = 0
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
for sess in sessions.values():
|
||||
if getattr(sess, "queen_name", None) != queen_id:
|
||||
continue
|
||||
phase_state = getattr(sess, "phase_state", None)
|
||||
if phase_state is None:
|
||||
continue
|
||||
phase_state.enabled_mcp_tools = enabled
|
||||
rebuild = getattr(phase_state, "rebuild_independent_filter", None)
|
||||
if callable(rebuild):
|
||||
try:
|
||||
rebuild()
|
||||
refreshed += 1
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Queen tools: rebuild_independent_filter failed for session %s",
|
||||
getattr(sess, "id", "?"),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Queen tools: queen_id=%s allowlist=%s refreshed_sessions=%d",
|
||||
queen_id,
|
||||
"null" if enabled is None else f"{len(enabled)} tool(s)",
|
||||
refreshed,
|
||||
)
|
||||
return web.json_response(
|
||||
{
|
||||
"queen_id": queen_id,
|
||||
"enabled_mcp_tools": enabled,
|
||||
"refreshed_sessions": refreshed,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_delete_tools(request: web.Request) -> web.Response:
|
||||
"""DELETE /api/queen/{queen_id}/tools — drop the sidecar, fall back to role defaults.
|
||||
|
||||
Users click "Reset to role default" in the Tool Library. That
|
||||
removes ``tools.json`` so the queen's effective allowlist becomes
|
||||
the role-based default (or allow-all if the queen has no role
|
||||
entry). Live sessions are refreshed so the next turn reflects the
|
||||
change without a restart.
|
||||
"""
|
||||
queen_id = request.match_info["queen_id"]
|
||||
ensure_default_queens()
|
||||
try:
|
||||
load_queen_profile(queen_id)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
removed = delete_queen_tools_config(queen_id)
|
||||
|
||||
# Recompute the queen's effective allowlist from the role defaults
|
||||
# so we can hot-reload live sessions in one pass (same shape as
|
||||
# PATCH).
|
||||
manager = request.app.get("manager")
|
||||
session = _live_queen_session(manager, queen_id) if manager is not None else None
|
||||
if session is not None:
|
||||
catalog = _catalog_from_live_session(session)
|
||||
else:
|
||||
catalog = await _ensure_manager_catalog(manager)
|
||||
new_enabled = load_queen_tools_config(queen_id, mcp_catalog=catalog)
|
||||
|
||||
refreshed = 0
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
for sess in sessions.values():
|
||||
if getattr(sess, "queen_name", None) != queen_id:
|
||||
continue
|
||||
phase_state = getattr(sess, "phase_state", None)
|
||||
if phase_state is None:
|
||||
continue
|
||||
phase_state.enabled_mcp_tools = new_enabled
|
||||
rebuild = getattr(phase_state, "rebuild_independent_filter", None)
|
||||
if callable(rebuild):
|
||||
try:
|
||||
rebuild()
|
||||
refreshed += 1
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Queen tools: rebuild_independent_filter failed for session %s",
|
||||
getattr(sess, "id", "?"),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Queen tools: queen_id=%s reset-to-default removed=%s refreshed_sessions=%d",
|
||||
queen_id,
|
||||
removed,
|
||||
refreshed,
|
||||
)
|
||||
return web.json_response(
|
||||
{
|
||||
"queen_id": queen_id,
|
||||
"removed": removed,
|
||||
"enabled_mcp_tools": new_enabled,
|
||||
"is_role_default": True,
|
||||
"refreshed_sessions": refreshed,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
"""Register queen-tools routes."""
|
||||
app.router.add_get("/api/queen/{queen_id}/tools", handle_get_tools)
|
||||
app.router.add_patch("/api/queen/{queen_id}/tools", handle_patch_tools)
|
||||
app.router.add_delete("/api/queen/{queen_id}/tools", handle_delete_tools)
|
||||
@@ -248,15 +248,22 @@ async def handle_queen_session(request: web.Request) -> web.Response:
|
||||
# Skip colony sessions: a colony forked from this queen also carries
|
||||
# queen_name == queen_id, but it has a worker loaded (colony_id /
|
||||
# worker_path set) and is the colony's chat, not the queen's DM.
|
||||
for session in manager.list_sessions():
|
||||
if session.queen_name == queen_id and session.colony_id is None and session.worker_path is None:
|
||||
return web.json_response(
|
||||
{
|
||||
"session_id": session.id,
|
||||
"queen_id": queen_id,
|
||||
"status": "live",
|
||||
}
|
||||
)
|
||||
# When multiple DM sessions for this queen are live at once (e.g. the
|
||||
# user created a new session, then navigated away and back), return
|
||||
# the most recently loaded one so we don't resurrect a stale older
|
||||
# session ahead of a freshly created one.
|
||||
live_matches = [
|
||||
s for s in manager.list_sessions() if s.queen_name == queen_id and s.colony_id is None and s.worker_path is None
|
||||
]
|
||||
if live_matches:
|
||||
latest = max(live_matches, key=lambda s: s.loaded_at)
|
||||
return web.json_response(
|
||||
{
|
||||
"session_id": latest.id,
|
||||
"queen_id": queen_id,
|
||||
"status": "live",
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Find the most recent cold session for this queen and resume it.
|
||||
# IMPORTANT: skip sessions that don't belong in the queen DM:
|
||||
@@ -378,6 +385,8 @@ async def handle_select_queen_session(request: web.Request) -> web.Response:
|
||||
|
||||
async def handle_new_queen_session(request: web.Request) -> web.Response:
|
||||
"""POST /api/queen/{queen_id}/session/new -- create a fresh queen session."""
|
||||
from framework.tools.queen_lifecycle_tools import QUEEN_PHASES
|
||||
|
||||
queen_id = request.match_info["queen_id"]
|
||||
manager = request.app["manager"]
|
||||
|
||||
@@ -387,9 +396,25 @@ async def handle_new_queen_session(request: web.Request) -> web.Response:
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
body = await request.json() if request.can_read_body else {}
|
||||
if request.can_read_body:
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict):
|
||||
return web.json_response({"error": "Request body must be a JSON object"}, status=400)
|
||||
else:
|
||||
body = {}
|
||||
initial_prompt = body.get("initial_prompt")
|
||||
initial_phase = body.get("initial_phase") or "independent"
|
||||
if initial_phase not in QUEEN_PHASES:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"Invalid initial_phase '{initial_phase}'",
|
||||
"valid": sorted(QUEEN_PHASES),
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
session = await manager.create_session(
|
||||
initial_prompt=initial_prompt,
|
||||
|
||||
@@ -122,8 +122,19 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
(equivalent to the old POST /api/agents). Otherwise creates a queen-only
|
||||
session that can later have a colony loaded via POST /sessions/{id}/colony.
|
||||
"""
|
||||
from framework.agents.queen.queen_profiles import ensure_default_queens, load_queen_profile
|
||||
from framework.tools.queen_lifecycle_tools import QUEEN_PHASES
|
||||
|
||||
manager = _get_manager(request)
|
||||
body = await request.json() if request.can_read_body else {}
|
||||
if request.can_read_body:
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict):
|
||||
return web.json_response({"error": "Request body must be a JSON object"}, status=400)
|
||||
else:
|
||||
body = {}
|
||||
agent_path = body.get("agent_path")
|
||||
agent_id = body.get("agent_id")
|
||||
session_id = body.get("session_id")
|
||||
@@ -134,6 +145,21 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
initial_phase = body.get("initial_phase")
|
||||
worker_name = body.get("worker_name")
|
||||
|
||||
if initial_phase is not None and initial_phase not in QUEEN_PHASES:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"Invalid initial_phase '{initial_phase}'",
|
||||
"valid": sorted(QUEEN_PHASES),
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
if queen_name:
|
||||
ensure_default_queens()
|
||||
try:
|
||||
load_queen_profile(queen_name)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_name}' not found"}, status=404)
|
||||
|
||||
if agent_path:
|
||||
try:
|
||||
agent_path = str(validate_agent_path(agent_path))
|
||||
@@ -160,6 +186,7 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
queen_name=queen_name,
|
||||
initial_phase=initial_phase,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1223,8 +1223,27 @@ class SessionManager:
|
||||
logger.info("Session '%s': shutdown reflection spawned", session_id)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
logger.warning("Session '%s': failed to spawn shutdown reflection", session_id, exc_info=True)
|
||||
except RuntimeError as exc:
|
||||
# Most common when a session is stopped after the event loop
|
||||
# has closed (e.g. during server shutdown or from an atexit
|
||||
# handler). The reflection would have had nothing to write
|
||||
# anyway — no new turns since the last periodic reflection.
|
||||
logger.warning(
|
||||
"Session '%s': shutdown reflection skipped — event loop unavailable (%s). "
|
||||
"Normal during server shutdown; anything worth persisting was saved by the "
|
||||
"periodic reflection after the last turn.",
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Session '%s': failed to spawn shutdown reflection: %s: %s. "
|
||||
"Check that queen_dir exists and session.llm is configured; full traceback follows.",
|
||||
session_id,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if session.queen_task is not None:
|
||||
session.queen_task.cancel()
|
||||
@@ -1516,8 +1535,46 @@ class SessionManager:
|
||||
tool_executor=queen_tool_executor,
|
||||
event_bus=session.event_bus,
|
||||
colony_id=session.id,
|
||||
# Wire the on-disk colony name and queen id so
|
||||
# ColonyRuntime auto-derives its override paths. DM sessions
|
||||
# have no colony_name (session.colony_name is None), which
|
||||
# keeps them out of the per-colony JSON store.
|
||||
colony_name=getattr(session, "colony_name", None),
|
||||
queen_id=getattr(session, "queen_name", None) or None,
|
||||
pipeline_stages=[], # queen pipeline runs in queen_orchestrator, not here
|
||||
)
|
||||
|
||||
# Per-colony tool allowlist, loaded from the colony's metadata.json
|
||||
# when this session is attached to a real forked colony. For pure
|
||||
# queen DM sessions (session.colony_name is None) we only capture
|
||||
# the MCP-origin set — the allowlist stays ``None`` so every MCP
|
||||
# tool passes through by default.
|
||||
try:
|
||||
mcp_tool_names_all: set[str] = set()
|
||||
mgr_catalog = getattr(self, "_mcp_tool_catalog", None)
|
||||
if isinstance(mgr_catalog, dict):
|
||||
for entries in mgr_catalog.values():
|
||||
for entry in entries:
|
||||
name = entry.get("name") if isinstance(entry, dict) else None
|
||||
if name:
|
||||
mcp_tool_names_all.add(name)
|
||||
enabled_mcp_tools: list[str] | None = None
|
||||
colony_name = getattr(session, "colony_name", None)
|
||||
if colony_name:
|
||||
# Colony tool allowlist lives in a dedicated tools.json
|
||||
# sidecar next to metadata.json. The helper migrates any
|
||||
# legacy field out of metadata.json on first read.
|
||||
from framework.host.colony_tools_config import load_colony_tools_config
|
||||
|
||||
enabled_mcp_tools = load_colony_tools_config(colony_name)
|
||||
colony.set_tool_allowlist(enabled_mcp_tools, mcp_tool_names_all)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Colony allowlist bootstrap failed for session %s",
|
||||
session.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await colony.start()
|
||||
session.colony = colony
|
||||
|
||||
@@ -1707,6 +1764,42 @@ class SessionManager:
|
||||
def list_sessions(self) -> list[Session]:
|
||||
return list(self._sessions.values())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Skill override helpers — used by routes_skills to find every live
|
||||
# SkillsManager affected by a queen- or colony-scope mutation so a
|
||||
# single HTTP call can reload them all.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def iter_queen_sessions(self, queen_id: str):
|
||||
"""Yield live sessions whose queen matches ``queen_id``."""
|
||||
for s in self._sessions.values():
|
||||
if getattr(s, "queen_name", None) == queen_id:
|
||||
yield s
|
||||
|
||||
def iter_colony_runtimes(
|
||||
self,
|
||||
*,
|
||||
queen_id: str | None = None,
|
||||
colony_name: str | None = None,
|
||||
):
|
||||
"""Yield live ``ColonyRuntime`` instances matching the filters.
|
||||
|
||||
``queen_id`` alone → every runtime whose ``queen_id`` matches
|
||||
(useful when the user toggles a queen-scope skill — all her
|
||||
colonies must reload). ``colony_name`` alone → the single
|
||||
runtime pinned to that colony. Both → intersection. No filters
|
||||
→ every live runtime (used by global ``/api/skills`` reload).
|
||||
"""
|
||||
for s in self._sessions.values():
|
||||
colony = getattr(s, "colony", None)
|
||||
if colony is None:
|
||||
continue
|
||||
if queen_id is not None and getattr(colony, "queen_id", None) != queen_id:
|
||||
continue
|
||||
if colony_name is not None and getattr(colony, "colony_name", None) != colony_name:
|
||||
continue
|
||||
yield colony
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cold session helpers (disk-only, no live runtime required)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Tests for the per-colony MCP tool allowlist filter + routes.
|
||||
|
||||
Covers:
|
||||
1. ``ColonyRuntime`` filter semantics (default-allow, allowlist, empty,
|
||||
lifecycle passes through).
|
||||
2. routes_colony_tools round trip (GET/PATCH, validation, 404).
|
||||
3. Colony index route for the Tool Library picker.
|
||||
|
||||
Routes never touch the real ``~/.hive/colonies`` tree — we redirect
|
||||
``COLONIES_DIR`` into ``tmp_path`` via monkeypatch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.host.colony_runtime import ColonyRuntime
|
||||
from framework.llm.provider import Tool
|
||||
from framework.server import routes_colony_tools
|
||||
|
||||
|
||||
def _tool(name: str) -> Tool:
|
||||
return Tool(name=name, description=f"desc of {name}", parameters={"type": "object"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ColonyRuntime filter unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _bare_runtime() -> ColonyRuntime:
|
||||
rt = ColonyRuntime.__new__(ColonyRuntime)
|
||||
rt._enabled_mcp_tools = None
|
||||
rt._mcp_tool_names_all = set()
|
||||
return rt
|
||||
|
||||
|
||||
class TestColonyFilter:
|
||||
def test_default_is_noop(self):
|
||||
rt = _bare_runtime()
|
||||
tools = [_tool("mcp_a"), _tool("lc_b")]
|
||||
assert rt._apply_tool_allowlist(tools) == tools
|
||||
|
||||
def test_allowlist_gates_mcp_only(self):
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"mcp_a", "mcp_b"}
|
||||
rt._enabled_mcp_tools = ["mcp_a"]
|
||||
tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
|
||||
names = [t.name for t in rt._apply_tool_allowlist(tools)]
|
||||
assert names == ["mcp_a", "lc_c"]
|
||||
|
||||
def test_empty_allowlist_keeps_lifecycle(self):
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"mcp_a", "mcp_b"}
|
||||
rt._enabled_mcp_tools = []
|
||||
tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
|
||||
names = [t.name for t in rt._apply_tool_allowlist(tools)]
|
||||
assert names == ["lc_c"]
|
||||
|
||||
def test_setter_mutates_live_state(self):
|
||||
rt = _bare_runtime()
|
||||
rt.set_tool_allowlist(["x"], {"x", "y"})
|
||||
assert rt._enabled_mcp_tools == ["x"]
|
||||
assert rt._mcp_tool_names_all == {"x", "y"}
|
||||
|
||||
# Passing None on allowlist clears gating; mcp_tool_names_all
|
||||
# defaults to "keep current" so a subsequent caller doesn't need
|
||||
# to repeat the set.
|
||||
rt.set_tool_allowlist(None)
|
||||
assert rt._enabled_mcp_tools is None
|
||||
assert rt._mcp_tool_names_all == {"x", "y"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route round-trip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeSession:
|
||||
colony_name: str
|
||||
colony: Any = None
|
||||
colony_runtime: Any = None
|
||||
id: str = "sess-1"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeManager:
|
||||
_sessions: dict = field(default_factory=dict)
|
||||
_mcp_tool_catalog: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def colony_dir(tmp_path, monkeypatch):
|
||||
"""Point COLONIES_DIR into a tmp tree and seed a colony."""
|
||||
colonies = tmp_path / "colonies"
|
||||
colonies.mkdir()
|
||||
monkeypatch.setattr("framework.host.colony_metadata.COLONIES_DIR", colonies)
|
||||
monkeypatch.setattr("framework.host.colony_tools_config.COLONIES_DIR", colonies)
|
||||
|
||||
name = "my_colony"
|
||||
cdir = colonies / name
|
||||
cdir.mkdir()
|
||||
(cdir / "metadata.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"colony_name": name,
|
||||
"queen_name": "queen_technology",
|
||||
"created_at": "2026-04-20T00:00:00+00:00",
|
||||
}
|
||||
)
|
||||
)
|
||||
return colonies, name
|
||||
|
||||
|
||||
async def _app(manager: _FakeManager) -> web.Application:
|
||||
app = web.Application()
|
||||
app["manager"] = manager
|
||||
routes_colony_tools.register_routes(app)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_default_allow(colony_dir):
|
||||
_, name = colony_dir
|
||||
manager = _FakeManager(
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "read", "input_schema": {}},
|
||||
{"name": "write_file", "description": "write", "input_schema": {}},
|
||||
],
|
||||
}
|
||||
)
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get(f"/api/colony/{name}/tools")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["enabled_mcp_tools"] is None
|
||||
assert body["stale"] is False
|
||||
tools = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
|
||||
assert all(t["enabled"] for t in tools.values())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_persists_and_validates(colony_dir):
|
||||
colonies_dir, name = colony_dir
|
||||
manager = _FakeManager(
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "", "input_schema": {}},
|
||||
{"name": "write_file", "description": "", "input_schema": {}},
|
||||
]
|
||||
}
|
||||
)
|
||||
app = await _app(manager)
|
||||
tools_path = colonies_dir / name / "tools.json"
|
||||
metadata_path = colonies_dir / name / "metadata.json"
|
||||
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["read_file"]})
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["enabled_mcp_tools"] == ["read_file"]
|
||||
|
||||
# Persisted to tools.json; metadata.json does not carry the field.
|
||||
sidecar = json.loads(tools_path.read_text())
|
||||
assert sidecar["enabled_mcp_tools"] == ["read_file"]
|
||||
assert "updated_at" in sidecar
|
||||
meta = json.loads(metadata_path.read_text())
|
||||
assert "enabled_mcp_tools" not in meta
|
||||
|
||||
# GET reflects the allowlist
|
||||
resp = await client.get(f"/api/colony/{name}/tools")
|
||||
body = await resp.json()
|
||||
tools = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
|
||||
assert tools["read_file"]["enabled"] is True
|
||||
assert tools["write_file"]["enabled"] is False
|
||||
|
||||
# Unknown → 400
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["ghost"]})
|
||||
assert resp.status == 400
|
||||
assert "ghost" in (await resp.json()).get("unknown", [])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_refreshes_live_runtime(colony_dir):
|
||||
_, name = colony_dir
|
||||
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"read_file", "write_file"}
|
||||
rt.set_tool_allowlist(None)
|
||||
|
||||
session = _FakeSession(colony_name=name, colony=rt)
|
||||
manager = _FakeManager(
|
||||
_sessions={session.id: session},
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "", "input_schema": {}},
|
||||
{"name": "write_file", "description": "", "input_schema": {}},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["read_file"]})
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["refreshed_runtimes"] == 1
|
||||
assert rt._enabled_mcp_tools == ["read_file"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_for_unknown_colony(colony_dir):
|
||||
manager = _FakeManager()
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/colony/unknown/tools")
|
||||
assert resp.status == 404
|
||||
resp = await client.patch("/api/colony/unknown/tools", json={"enabled_mcp_tools": None})
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_index_lists_colonies(colony_dir):
|
||||
_, name = colony_dir
|
||||
manager = _FakeManager()
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/colonies/tools-index")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
entries = {c["name"]: c for c in body["colonies"]}
|
||||
assert name in entries
|
||||
assert entries[name]["queen_name"] == "queen_technology"
|
||||
assert entries[name]["has_allowlist"] is False
|
||||
|
||||
|
||||
def test_queen_allowlist_inherits_into_new_colony(tmp_path, monkeypatch):
|
||||
"""A colony forked with a curated queen inherits her allowlist.
|
||||
|
||||
Exercises the inheritance hook in
|
||||
``routes_execution.fork_session_into_colony`` without running the
|
||||
full fork machinery — we just call
|
||||
``update_colony_tools_config`` the same way the hook does and
|
||||
assert the colony's ``tools.json`` matches the queen's live list.
|
||||
"""
|
||||
colonies = tmp_path / "colonies"
|
||||
colonies.mkdir()
|
||||
monkeypatch.setattr("framework.host.colony_tools_config.COLONIES_DIR", colonies)
|
||||
|
||||
from framework.host.colony_tools_config import (
|
||||
load_colony_tools_config,
|
||||
update_colony_tools_config,
|
||||
)
|
||||
|
||||
colony_name = "forked_child"
|
||||
(colonies / colony_name).mkdir()
|
||||
|
||||
# Simulate: queen has a curated allowlist (e.g. role default resolved
|
||||
# to a concrete list). The inheritance hook copies it verbatim.
|
||||
queen_live_allowlist = ["read_file", "web_scrape", "csv_read"]
|
||||
update_colony_tools_config(colony_name, list(queen_live_allowlist))
|
||||
|
||||
assert load_colony_tools_config(colony_name) == queen_live_allowlist
|
||||
|
||||
|
||||
def test_legacy_metadata_field_migrates_to_sidecar(colony_dir):
|
||||
"""A legacy enabled_mcp_tools field in metadata.json is hoisted to tools.json."""
|
||||
colonies_dir, name = colony_dir
|
||||
meta_path = colonies_dir / name / "metadata.json"
|
||||
tools_path = colonies_dir / name / "tools.json"
|
||||
|
||||
# Seed legacy field in metadata.json.
|
||||
meta = json.loads(meta_path.read_text())
|
||||
meta["enabled_mcp_tools"] = ["read_file"]
|
||||
meta_path.write_text(json.dumps(meta))
|
||||
|
||||
from framework.host.colony_tools_config import load_colony_tools_config
|
||||
|
||||
# First load migrates.
|
||||
assert load_colony_tools_config(name) == ["read_file"]
|
||||
assert tools_path.exists()
|
||||
sidecar = json.loads(tools_path.read_text())
|
||||
assert sidecar["enabled_mcp_tools"] == ["read_file"]
|
||||
|
||||
# metadata.json no longer contains the field; provenance fields preserved.
|
||||
migrated = json.loads(meta_path.read_text())
|
||||
assert "enabled_mcp_tools" not in migrated
|
||||
assert migrated["queen_name"] == "queen_technology"
|
||||
|
||||
# Second load is a direct sidecar read.
|
||||
assert load_colony_tools_config(name) == ["read_file"]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Tests for the MCP server CRUD HTTP routes.
|
||||
|
||||
Monkey-patches ``MCPRegistry`` inside ``routes_mcp`` so the HTTP layer is
|
||||
exercised without reading or writing ``~/.hive/mcp_registry/installed.json``
|
||||
or spawning actual subprocesses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.loader.mcp_errors import MCPError, MCPErrorCode
|
||||
from framework.server import routes_mcp
|
||||
|
||||
|
||||
class _FakeRegistry:
    """Stand-in for MCPRegistry — just enough surface for the routes.

    Keeps an in-memory ``name -> entry`` map seeded with one
    ``registry``-sourced server so list/get have something to return
    without touching ``~/.hive`` or spawning subprocesses.
    """

    def __init__(self) -> None:
        # Entry value shape mirrors what the real registry persists.
        self._servers: dict[str, dict[str, Any]] = {
            "built-in-seed": {
                "source": "registry",
                "transport": "stdio",
                "enabled": True,
                "manifest": {"description": "Factory-seeded server", "tools": []},
                "last_health_status": "healthy",
                "last_error": None,
                "last_health_check_at": None,
            }
        }

    def _require(self, name: str, code: "MCPErrorCode") -> None:
        """Raise the shared placeholder not-found MCPError for *name*.

        Factored out of enable/disable/health_check, which previously
        duplicated the identical raise block three times.
        """
        if name not in self._servers:
            raise MCPError(
                code=code,
                what="not found",
                why="not found",
                fix="x",
            )

    def initialize(self) -> None:  # noqa: D401 — registry idempotent init
        return

    def list_installed(self) -> list[dict[str, Any]]:
        """Return every entry with its name folded into the dict."""
        return [{"name": name, **entry} for name, entry in self._servers.items()]

    def get_server(self, name: str) -> dict | None:
        """Return a single named entry, or None when not installed."""
        if name not in self._servers:
            return None
        return {"name": name, **self._servers[name]}

    def add_local(self, *, name: str, transport: str, **kwargs: Any) -> dict:
        """Register a user-added server; reject duplicate names."""
        if name in self._servers:
            raise MCPError(
                code=MCPErrorCode.MCP_INSTALL_FAILED,
                what=f"Server '{name}' already exists",
                why="A server with this name is already registered locally.",
                fix=f"Run: hive mcp remove {name}",
            )
        entry = {
            "source": "local",
            "transport": transport,
            "enabled": True,
            "manifest": {"description": kwargs.get("description") or ""},
            "last_health_status": None,
            "last_error": None,
            "last_health_check_at": None,
        }
        self._servers[name] = entry
        return entry

    def remove(self, name: str) -> None:
        """Delete an installed entry; missing names get a descriptive error."""
        if name not in self._servers:
            raise MCPError(
                code=MCPErrorCode.MCP_INSTALL_FAILED,
                what=f"Cannot remove server '{name}'",
                why="Server is not installed.",
                fix="Run: hive mcp list",
            )
        del self._servers[name]

    def enable(self, name: str) -> None:
        """Flip the enabled flag on; 'not found' error for unknown names."""
        self._require(name, MCPErrorCode.MCP_INSTALL_FAILED)
        self._servers[name]["enabled"] = True

    def disable(self, name: str) -> None:
        """Flip the enabled flag off; 'not found' error for unknown names."""
        self._require(name, MCPErrorCode.MCP_INSTALL_FAILED)
        self._servers[name]["enabled"] = False

    def health_check(self, name: str) -> dict[str, Any]:
        """Report a canned healthy status with a fixed tool count of 3."""
        self._require(name, MCPErrorCode.MCP_HEALTH_FAILED)
        return {"name": name, "status": "healthy", "tools": 3, "error": None}
|
||||
|
||||
|
||||
@pytest.fixture
def registry(monkeypatch):
    """Install a fresh in-memory _FakeRegistry behind routes_mcp for this test."""
    fake = _FakeRegistry()
    monkeypatch.setattr(routes_mcp, "_registry", lambda: fake)
    return fake
|
||||
|
||||
|
||||
async def _make_app() -> web.Application:
    """Build a throwaway aiohttp application with the MCP routes mounted."""
    application = web.Application()
    routes_mcp.register_routes(application)
    return application
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_servers_returns_built_in(registry):
    """GET /api/mcp/servers merges registry entries with package-baked ones."""
    app = await _make_app()
    async with TestClient(TestServer(app)) as client:
        resp = await client.get("/api/mcp/servers")
        assert resp.status == 200
        payload = await resp.json()
        servers = payload["servers"]
        # The registry fake carries one entry; the list also merges package-
        # baked entries from core/framework/agents/queen/mcp_servers.json so
        # the UI matches what the queen actually loads. Both should appear.
        assert "built-in-seed" in {entry["name"] for entry in servers}
        source_by_name = {entry["name"]: entry["source"] for entry in servers}
        assert source_by_name.get("built-in-seed") == "registry"
        # The package-baked servers (coder-tools/gcu-tools/hive_tools) carry
        # source=="built-in" and are non-removable.
        baked = [entry for entry in servers if entry["source"] == "built-in"]
        assert baked, "expected at least one package-baked MCP server"
        assert all(entry.get("removable") is False for entry in baked)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_local_server(registry):
    """POST creates a local server that then appears in the listing."""
    app = await _make_app()
    payload = {
        "name": "my-tool",
        "transport": "stdio",
        "command": "echo",
        "args": ["hi"],
        "description": "says hi",
    }
    async with TestClient(TestServer(app)) as client:
        created = await client.post("/api/mcp/servers", json=payload)
        assert created.status == 201
        server = (await created.json())["server"]
        assert server["name"] == "my-tool"
        assert server["source"] == "local"

        listing = await client.get("/api/mcp/servers")
        listed_names = [entry["name"] for entry in (await listing.json())["servers"]]
        assert "my-tool" in listed_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_rejects_duplicate(registry):
    """Registering the same name twice yields a 409 with a fix hint."""
    app = await _make_app()
    payload = {"name": "dup", "transport": "stdio", "command": "x"}
    async with TestClient(TestServer(app)) as client:
        await client.post("/api/mcp/servers", json=payload)
        resp = await client.post("/api/mcp/servers", json=payload)
        assert resp.status == 409
        body = await resp.json()
        assert "already exists" in body["error"].lower()
        assert body["fix"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_rejects_invalid_transport(registry):
    """An unsupported transport value is rejected with a 400."""
    app = await _make_app()
    bad_payload = {"name": "x", "transport": "nope"}
    async with TestClient(TestServer(app)) as client:
        resp = await client.post("/api/mcp/servers", json=bad_payload)
        assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enable_disable_cycle(registry):
    """disable then enable flips the flag both in the response and registry."""
    app = await _make_app()
    # Seed a local server
    registry.add_local(name="local-one", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        disabled = await client.post("/api/mcp/servers/local-one/disable")
        assert disabled.status == 200
        assert (await disabled.json())["enabled"] is False
        assert registry._servers["local-one"]["enabled"] is False

        enabled = await client.post("/api/mcp/servers/local-one/enable")
        assert enabled.status == 200
        assert (await enabled.json())["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_remove_local_only(registry):
    """DELETE protects built-ins, 404s unknowns, and removes local servers."""
    app = await _make_app()
    registry.add_local(name="local-two", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        # Built-ins are protected
        protected = await client.delete("/api/mcp/servers/built-in-seed")
        assert protected.status == 400

        # Missing
        missing = await client.delete("/api/mcp/servers/ghost")
        assert missing.status == 404

        # Happy path
        removed = await client.delete("/api/mcp/servers/local-two")
        assert removed.status == 200
        assert "local-two" not in registry._servers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check(registry):
    """POST /health surfaces the registry's health report for the server.

    The previous signature also requested the ``monkeypatch`` fixture but
    never used it — dropped.
    """
    app = await _make_app()
    registry.add_local(name="pingable", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        resp = await client.post("/api/mcp/servers/pingable/health")
        assert resp.status == 200
        body = await resp.json()
        assert body["status"] == "healthy"
        assert body["tools"] == 3
|
||||
@@ -0,0 +1,443 @@
|
||||
"""Tests for the per-queen MCP tool allowlist filter + routes.
|
||||
|
||||
Covers:
|
||||
1. QueenPhaseState filter semantics (default-allow, allowlist, empty, phase-
|
||||
isolation, memo identity for LLM prompt-cache stability).
|
||||
2. routes_queen_tools round trip (GET, PATCH, validation, live-session
|
||||
hot-reload).
|
||||
|
||||
Route tests monkey-patch a tiny queen profile + manager catalog; they never
|
||||
spawn an MCP subprocess.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.llm.provider import Tool
|
||||
from framework.server import routes_queen_tools
|
||||
from framework.tools.queen_lifecycle_tools import QueenPhaseState
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QueenPhaseState filter — pure unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _tool(name: str) -> Tool:
    """Build a minimal Tool stub whose description embeds *name*."""
    return Tool(
        name=name,
        description=f"desc of {name}",
        parameters={"type": "object"},
    )
|
||||
|
||||
|
||||
class TestPhaseStateFilter:
    """Pure unit tests for QueenPhaseState's independent-phase MCP filter."""

    def test_default_allow_returns_every_tool(self):
        """enabled_mcp_tools=None (no allowlist) passes every tool through."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = None
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["mcp_a", "mcp_b", "lc_c"]

    def test_allowlist_keeps_listed_mcp_plus_all_lifecycle(self):
        """The allowlist gates MCP tools only; non-MCP (lifecycle) tools stay."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = ["mcp_a"]
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["mcp_a", "lc_c"]

    def test_empty_allowlist_keeps_only_lifecycle(self):
        """An explicit empty allowlist strips every MCP tool."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = []
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["lc_c"]

    def test_filter_isolated_to_independent_phase(self):
        """The filter applies only while phase == 'independent'."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.working_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a"}
        ps.enabled_mcp_tools = []
        ps.rebuild_independent_filter()

        # Independent → filtered
        assert [t.name for t in ps.get_current_tools()] == ["lc_c"]

        # Other phases → unaffected
        ps.phase = "working"
        assert [t.name for t in ps.get_current_tools()] == ["mcp_a", "lc_c"]

    def test_memo_returns_stable_identity_for_prompt_cache(self):
        """Same Python list object across turns → LLM prompt cache stays warm."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a"}
        ps.enabled_mcp_tools = None
        ps.rebuild_independent_filter()

        first = ps.get_current_tools()
        second = ps.get_current_tools()
        assert first is second, "memoized list must be the same object across turns"

        # A rebuild should produce a different object so downstream caches
        # correctly invalidate.
        ps.enabled_mcp_tools = ["mcp_a"]
        ps.rebuild_independent_filter()
        third = ps.get_current_tools()
        assert third is not first
        assert [t.name for t in third] == ["mcp_a", "lc_c"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route round-trip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class _FakeSession:
    """Minimal live-session stand-in for the hot-reload route path."""

    queen_name: str
    phase_state: QueenPhaseState
    colony_runtime: Any = None
    id: str = "sess-1"  # matches the key tests use in _FakeManager._sessions
    _queen_tool_registry: Any = None  # assigned per-test when a registry is needed
|
||||
|
||||
|
||||
@dataclass
class _FakeManager:
    """Session-manager stand-in exposing only the attributes the routes read."""

    # Live sessions keyed by session id; used for hot-reload fan-out.
    _sessions: dict = field(default_factory=dict)
    # server name -> list of tool dicts ({"name", "description", "input_schema"}).
    _mcp_tool_catalog: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@pytest.fixture
def queen_dir(tmp_path, monkeypatch):
    """Point queen profile + tools storage at a throwaway tmp dir."""
    root = tmp_path / "queens"
    root.mkdir()
    for attr in (
        "framework.agents.queen.queen_profiles.QUEENS_DIR",
        "framework.agents.queen.queen_tools_config.QUEENS_DIR",
    ):
        monkeypatch.setattr(attr, root)

    qid = "queen_technology"
    profile_dir = root / qid
    profile_dir.mkdir()
    profile = {"name": "Alexandra", "title": "Head of Technology"}
    (profile_dir / "profile.yaml").write_text(yaml.safe_dump(profile))
    return root, qid
|
||||
|
||||
|
||||
async def _make_app(*, manager: _FakeManager) -> web.Application:
    """Mount the queen-tools routes over the given fake manager."""
    application = web.Application()
    application["manager"] = manager
    routes_queen_tools.register_routes(application)
    return application
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_tools_default_allows_everything_for_unknown_queen(queen_dir, monkeypatch):
    """Queens NOT in the role-default table fall back to allow-all."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)

    queens_dir, _ = queen_dir
    # Use a queen id that isn't in QUEEN_DEFAULT_CATEGORIES so we exercise
    # the fallback-to-allow-all path.
    custom_id = "queen_custom_unknown"
    (queens_dir / custom_id).mkdir()
    (queens_dir / custom_id / "profile.yaml").write_text(yaml.safe_dump({"name": "Custom", "title": "Custom Role"}))

    # Seed a one-server catalog so the response has tools to report on.
    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "read", "input_schema": {}},
            {"name": "write_file", "description": "write", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.get(f"/api/queen/{custom_id}/tools")
        assert resp.status == 200
        body = await resp.json()

        assert body["enabled_mcp_tools"] is None
        assert body["is_role_default"] is True  # no sidecar → default-allow
        assert body["stale"] is False
        servers = {s["name"]: s for s in body["mcp_servers"]}
        assert set(servers) == {"coder-tools"}
        # Allow-all: every catalog tool must come back enabled.
        for tool in servers["coder-tools"]["tools"]:
            assert tool["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_tools_applies_role_default(queen_dir, monkeypatch):
    """Known persona queens get their role-based default allowlist."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    _, queen_id = queen_dir  # queen_technology — has a role default

    manager = _FakeManager()
    # Seed a catalog covering tools the role default references so the
    # response reflects what the queen would actually see on boot.
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "port_scan", "description": "", "input_schema": {}},  # security
            {"name": "excel_read", "description": "", "input_schema": {}},  # data
            {"name": "fluffy_unknown_tool", "description": "", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        body = await resp.json()

        # queen_technology's role default includes file_read, data, security, etc.
        assert body["is_role_default"] is True
        enabled = set(body["enabled_mcp_tools"] or [])
        assert "read_file" in enabled
        assert "port_scan" in enabled  # technology role includes security
        assert "excel_read" in enabled
        # Tools not in any category (and not in a @server: expansion target
        # the role references) are NOT part of the default.
        assert "fluffy_unknown_tool" not in enabled
|
||||
|
||||
|
||||
def test_resolve_queen_default_tools_expands_server_shorthand():
    """@server:NAME shorthand expands against the provided catalog."""
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    browser_tools = [{"name": "browser_navigate"}, {"name": "browser_click"}]
    catalog = {"gcu-tools": browser_tools}
    # queen_brand_design uses "browser" category → expands via @server:gcu-tools.
    expanded = resolve_queen_default_tools("queen_brand_design", catalog)
    assert expanded is not None
    assert {"browser_navigate", "browser_click"} <= set(expanded)
|
||||
|
||||
|
||||
def test_resolve_queen_default_tools_unknown_queen_returns_none():
    """Queens without a role-default entry resolve to None."""
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    result = resolve_queen_default_tools("queen_made_up", {})
    assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_patch_persists_and_validates(queen_dir, monkeypatch):
    """PATCH writes tools.json, validates names, and never touches profile.yaml."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    queens_dir, queen_id = queen_dir

    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "write_file", "description": "", "input_schema": {}},
        ]
    }

    app = await _make_app(manager=manager)
    tools_path = queens_dir / queen_id / "tools.json"
    profile_path = queens_dir / queen_id / "profile.yaml"

    async with TestClient(TestServer(app)) as client:
        # Happy path
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        body = await resp.json()
        assert body["enabled_mcp_tools"] == ["read_file"]

        # Sidecar persisted; profile YAML untouched by tools PATCH
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] == ["read_file"]
        assert "updated_at" in sidecar
        profile = yaml.safe_load(profile_path.read_text())
        assert "enabled_mcp_tools" not in profile

        # GET reflects the new state
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        body = await resp.json()
        assert body["is_role_default"] is False  # user has explicitly saved
        servers = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
        assert servers["read_file"]["enabled"] is True
        assert servers["write_file"]["enabled"] is False

        # Null resets
        resp = await client.patch(f"/api/queen/{queen_id}/tools", json={"enabled_mcp_tools": None})
        assert resp.status == 200
        body = await resp.json()
        assert body["enabled_mcp_tools"] is None
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] is None

        # Unknown tool name → 400; sidecar unchanged
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["nope_not_a_tool"]},
        )
        assert resp.status == 400
        detail = await resp.json()
        assert "nope_not_a_tool" in detail.get("unknown", [])
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_patch_hot_reloads_live_session(queen_dir, monkeypatch):
    """PATCH pushes the new allowlist into a running session without restart."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    _, queen_id = queen_dir

    # Build a fake live session whose phase state carries a tool list the
    # filter can gate. We also need a fake registry so
    # _catalog_from_live_session can enumerate tools.
    class _FakeRegistry:
        def __init__(self, server_map, tools_by_name):
            self._mcp_server_tools = server_map
            self._tools_by_name = tools_by_name

        def get_tools(self):
            return {n: MagicMock(name=n) for n in self._tools_by_name}

    tools_by_name = {"read_file": _tool("read_file"), "write_file": _tool("write_file")}
    registry = _FakeRegistry(
        server_map={"coder-tools": {"read_file", "write_file"}},
        tools_by_name=tools_by_name,
    )
    # Patch get_tools to return real Tool objects for name/description plumbing.
    registry.get_tools = lambda: tools_by_name  # type: ignore[method-assign]

    # Start from the default allow-all state so the PATCH has to change it.
    phase_state = QueenPhaseState(phase="independent")
    phase_state.independent_tools = [tools_by_name["read_file"], tools_by_name["write_file"]]
    phase_state.mcp_tool_names_all = {"read_file", "write_file"}
    phase_state.enabled_mcp_tools = None
    phase_state.rebuild_independent_filter()

    session = _FakeSession(queen_name=queen_id, phase_state=phase_state)
    session._queen_tool_registry = registry
    manager = _FakeManager(_sessions={"sess-1": session})

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        body = await resp.json()
        assert body["refreshed_sessions"] == 1

    # Session's phase state reflects the new allowlist without a restart
    current = phase_state.get_current_tools()
    assert [t.name for t in current] == ["read_file"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_queen_returns_404(queen_dir, monkeypatch):
    """Both GET and PATCH 404 for a queen id with no profile on disk."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    manager = _FakeManager()

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        get_resp = await client.get("/api/queen/queen_nonexistent/tools")
        assert get_resp.status == 404

        patch_resp = await client.patch(
            "/api/queen/queen_nonexistent/tools",
            json={"enabled_mcp_tools": None},
        )
        assert patch_resp.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_restores_role_default(queen_dir, monkeypatch):
    """DELETE removes tools.json so the queen falls back to the role default."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    queens_dir, queen_id = queen_dir
    tools_path = queens_dir / queen_id / "tools.json"

    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "port_scan", "description": "", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        # Seed a custom allowlist first so we have a sidecar to delete.
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        assert tools_path.exists()

        resp = await client.delete(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        body = await resp.json()
        assert body["removed"] is True
        assert body["is_role_default"] is True
        assert not tools_path.exists()

        # The new effective list is the role default for queen_technology,
        # which includes both read_file (file_read) and port_scan (security).
        enabled = set(body["enabled_mcp_tools"] or [])
        assert "read_file" in enabled
        assert "port_scan" in enabled

        # GET confirms.
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        body = await resp.json()
        assert body["is_role_default"] is True

        # Deleting again is a no-op.
        resp = await client.delete(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        assert (await resp.json())["removed"] is False
|
||||
|
||||
|
||||
def test_legacy_profile_field_migrates_to_sidecar(queen_dir):
    """A legacy enabled_mcp_tools field in profile.yaml is hoisted to tools.json."""
    queens_dir, queen_id = queen_dir
    profile_path = queens_dir / queen_id / "profile.yaml"
    tools_path = queens_dir / queen_id / "tools.json"

    # Seed legacy field in profile.yaml.
    profile = yaml.safe_load(profile_path.read_text()) or {}
    profile["enabled_mcp_tools"] = ["read_file", "write_file"]
    profile_path.write_text(yaml.safe_dump(profile, sort_keys=False))

    from framework.agents.queen.queen_tools_config import load_queen_tools_config

    # First load migrates.
    assert load_queen_tools_config(queen_id) == ["read_file", "write_file"]
    assert tools_path.exists()
    # The sidecar now carries the hoisted list verbatim.
    sidecar = json.loads(tools_path.read_text())
    assert sidecar["enabled_mcp_tools"] == ["read_file", "write_file"]

    # profile.yaml no longer contains the field; other fields preserved.
    migrated_profile = yaml.safe_load(profile_path.read_text())
    assert "enabled_mcp_tools" not in migrated_profile
    assert migrated_profile["name"] == "Alexandra"

    # Second load is a direct read — no migration work to do.
    assert load_queen_tools_config(queen_id) == ["read_file", "write_file"]
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Shared skill authoring primitives.
|
||||
|
||||
Validates and materializes a skill folder. Used by three callers:
|
||||
|
||||
1. Queen's ``create_colony`` tool (``queen_lifecycle_tools.py``) — inline
|
||||
content passed by the queen during colony creation.
|
||||
2. HTTP POST / PUT routes under ``/api/**/skills`` — UI-driven creation.
|
||||
3. Future ``create_learned_skill`` tool — runtime learning.
|
||||
|
||||
Keeping the validators and writer here ensures the three paths share one
|
||||
authority; changes to the name regex or frontmatter layout happen in one
|
||||
place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Framework skill names include dots (``hive.note-taking``), so the
# validator needs to allow them even though the queen's ``create_colony``
# tool historically forbade dots. User-created skills without dots still
# pass; the dot cap just prevents us from rejecting existing framework
# names when the UI toggles them via ``validate_skill_name``.
_SKILL_NAME_RE = re.compile(r"^[a-z0-9.-]+$")  # lowercase alnum, dots, hyphens only
_MAX_NAME_LEN = 64  # hard cap on skill folder-name length, in characters
_MAX_DESC_LEN = 1024  # cap on the single-line frontmatter description, in characters
|
||||
|
||||
|
||||
@dataclass
class SkillFile:
    """Supporting file bundled with a skill (relative path + content)."""

    # Path relative to the skill folder root; validate_files guarantees it
    # is non-absolute and contains no '..' components.
    rel_path: Path
    # Full text written verbatim (UTF-8) by write_skill.
    content: str
|
||||
|
||||
|
||||
@dataclass
class SkillDraft:
    """Validated skill content ready to be written to disk."""

    name: str
    description: str
    body: str
    files: list[SkillFile] = field(default_factory=list)

    @property
    def skill_md_text(self) -> str:
        """Assemble the final SKILL.md text (frontmatter + body)."""
        # Exactly one trailing newline on the body, one blank line after
        # the closing frontmatter fence.
        normalized_body = self.body.rstrip() + "\n"
        frontmatter = f"---\nname: {self.name}\ndescription: {self.description}\n---\n"
        return frontmatter + "\n" + normalized_body
|
||||
|
||||
|
||||
def validate_skill_name(raw: str) -> tuple[str | None, str | None]:
    """Return ``(normalized_name, error)``. Either side may be None.

    Accepts lowercase alphanumerics plus dots and hyphens (dots so
    framework names like ``hive.note-taking`` validate); rejects empty
    input, leading/trailing/consecutive hyphens, and names longer than
    ``_MAX_NAME_LEN`` characters.
    """
    name = (raw or "").strip() if isinstance(raw, str) else ""
    if not name:
        return None, "skill_name is required"
    if not _SKILL_NAME_RE.match(name):
        # Message kept in sync with _SKILL_NAME_RE — it previously claimed
        # [a-z0-9-] even though the regex also allows dots.
        return None, f"skill_name '{name}' must match [a-z0-9.-] pattern"
    if name.startswith("-") or name.endswith("-") or "--" in name:
        return None, f"skill_name '{name}' has leading/trailing/consecutive hyphens"
    if len(name) > _MAX_NAME_LEN:
        return None, f"skill_name '{name}' exceeds {_MAX_NAME_LEN} chars"
    return name, None
|
||||
|
||||
|
||||
def validate_description(raw: str) -> tuple[str | None, str | None]:
|
||||
desc = (raw or "").strip() if isinstance(raw, str) else ""
|
||||
if not desc:
|
||||
return None, "skill_description is required"
|
||||
if len(desc) > _MAX_DESC_LEN:
|
||||
return None, f"skill_description must be 1–{_MAX_DESC_LEN} chars"
|
||||
# Frontmatter descriptions are line-oriented — the parser reads one value.
|
||||
if "\n" in desc or "\r" in desc:
|
||||
return None, "skill_description must be a single line (no newlines)"
|
||||
return desc, None
|
||||
|
||||
|
||||
def validate_files(raw: list[dict] | None) -> tuple[list[SkillFile] | None, str | None]:
    """Validate supporting-file specs and return ``(files, error)``.

    ``None`` or empty input is fine (a skill may be SKILL.md only). Each
    entry must be a dict with a relative, traversal-free string 'path'
    and a string 'content'; SKILL.md itself may not be supplied here.
    """
    if not raw:
        return [], None
    if not isinstance(raw, list):
        return None, "skill_files must be an array"
    validated: list[SkillFile] = []
    for item in raw:
        if not isinstance(item, dict):
            return None, "each skill_files entry must be an object with 'path' and 'content'"
        path_value = item.get("path")
        text = item.get("content")
        if not isinstance(path_value, str) or not path_value.strip():
            return None, "skill_files entry missing non-empty 'path'"
        if not isinstance(text, str):
            return None, f"skill_files entry '{path_value}' missing string 'content'"
        cleaned = path_value.strip()
        # Allow './foo' but reject '/foo' — relativizing absolute paths silently
        # has bitten other tools; make the intent loud instead.
        if cleaned.startswith("./"):
            cleaned = cleaned[2:]
        candidate = Path(cleaned)
        escapes = cleaned.startswith("/") or candidate.is_absolute() or ".." in candidate.parts
        if escapes:
            return None, f"skill_files path '{path_value}' must be relative and inside the skill folder"
        if candidate.as_posix() == "SKILL.md":
            return None, "skill_files must not contain SKILL.md — pass skill_body instead"
        validated.append(SkillFile(rel_path=candidate, content=text))
    return validated, None
|
||||
|
||||
|
||||
def build_draft(
    *,
    skill_name: str,
    skill_description: str,
    skill_body: str,
    skill_files: list[dict] | None = None,
) -> tuple[SkillDraft | None, str | None]:
    """Validate all inputs and return an immutable draft ready for writing."""
    normalized_name, problem = validate_skill_name(skill_name)
    if problem or normalized_name is None:
        return None, problem

    normalized_desc, problem = validate_description(skill_description)
    if problem or normalized_desc is None:
        return None, problem

    body_text = skill_body if isinstance(skill_body, str) else ""
    if not body_text.strip():
        return None, (
            "skill_body is required — the operational procedure the colony worker needs to run this job unattended"
        )

    file_list, problem = validate_files(skill_files)
    if problem or file_list is None:
        return None, problem

    draft = SkillDraft(
        name=normalized_name,
        description=normalized_desc,
        body=body_text,
        files=list(file_list),
    )
    return draft, None
|
||||
|
||||
|
||||
def write_skill(
    draft: SkillDraft,
    *,
    target_root: Path,
    replace_existing: bool = True,
) -> tuple[Path | None, str | None, bool]:
    """Materialize *draft* on disk under ``target_root/{draft.name}/``.

    ``target_root`` is the parent scope dir (e.g.
    ``~/.hive/agents/queens/{id}/skills`` or
    ``{colony_dir}/.hive/skills``); it is created when missing.

    Returns ``(installed_path, error, replaced)``. On success ``error``
    is ``None``; on failure ``installed_path`` is ``None`` and the
    target is left as it was before the call (best-effort).

    With ``replace_existing=False`` an already-present skill dir is
    refused with a non-fatal error (caller decides whether that is a
    409 or merely a warning) instead of being overwritten.
    """
    try:
        target_root.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        return None, f"failed to create skills root: {e}", False

    skill_dir = target_root / draft.name
    replaced = False
    try:
        if skill_dir.exists():
            if not replace_existing:
                return None, f"skill '{draft.name}' already exists", False
            # Blow away the previous version wholesale so stale files
            # from an earlier revision can't survive next to new ones.
            replaced = True
            shutil.rmtree(skill_dir)

        skill_dir.mkdir(parents=True, exist_ok=False)
        (skill_dir / "SKILL.md").write_text(draft.skill_md_text, encoding="utf-8")
        for skill_file in draft.files:
            destination = skill_dir / skill_file.rel_path
            destination.parent.mkdir(parents=True, exist_ok=True)
            destination.write_text(skill_file.content, encoding="utf-8")
    except OSError as e:
        return None, f"failed to write skill folder {skill_dir}: {e}", replaced

    return skill_dir, None, replaced
|
||||
|
||||
|
||||
def remove_skill(target_root: Path, skill_name: str) -> tuple[bool, str | None]:
    """Delete the skill directory ``target_root/{skill_name}/`` recursively.

    Returns ``(removed, error)``. A missing directory yields
    ``(False, None)``, making the operation idempotent. The name is
    validated on the way in, which keeps UI-supplied values from
    traversing outside the scope root.
    """
    safe_name, problem = validate_skill_name(skill_name)
    if problem or safe_name is None:
        return False, problem

    victim = target_root / safe_name
    if not victim.exists():
        # Nothing on disk — treat as a no-op rather than an error.
        return False, None

    try:
        shutil.rmtree(victim)
    except OSError as e:
        return False, f"failed to remove skill folder {victim}: {e}"
    return True, None
|
||||
@@ -7,7 +7,7 @@ locations. Resolves name collisions deterministically.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
@@ -30,16 +30,40 @@ _SKIP_DIRS = frozenset(
|
||||
)
|
||||
|
||||
# Scope priority (higher = takes precedence)
|
||||
# ``preset`` sits between framework and user: bundled alongside the
|
||||
# framework distribution, but off by default — capability packs the user
|
||||
# opts into per queen/colony rather than globally-enabled infra.
|
||||
_SCOPE_PRIORITY = {
|
||||
"framework": 0,
|
||||
"user": 1,
|
||||
"project": 2,
|
||||
"preset": 1,
|
||||
"user": 2,
|
||||
"queen_ui": 3,
|
||||
"colony_ui": 4,
|
||||
"project": 5,
|
||||
}
|
||||
|
||||
# Within the same scope, Hive-specific paths override cross-client paths.
|
||||
# We encode this by scanning cross-client first, then Hive-specific (later wins).
|
||||
|
||||
|
||||
@dataclass
class ExtraScope:
    """Additional scope dir to scan beyond the standard five.

    Used by :class:`framework.skills.manager.SkillsManager` to surface
    per-queen (``queen_ui``) and per-colony (``colony_ui``) skill
    directories created through the UI. The ``label`` feeds
    :attr:`ParsedSkill.source_scope` so downstream consumers (trust
    gate, UI provenance resolver) can distinguish scope origins.
    """

    # Directory to scan for skill folders; discovery skips it when it
    # does not exist on disk.
    directory: Path
    # Logical scope label (e.g. "queen_ui" / "colony_ui") recorded on
    # every skill parsed out of this directory.
    label: str
    # Kept for forward-compat with the priority table; discovery itself
    # relies on scan order for last-wins resolution.
    priority: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryConfig:
|
||||
"""Configuration for skill discovery."""
|
||||
@@ -49,6 +73,10 @@ class DiscoveryConfig:
|
||||
skip_framework_scope: bool = False
|
||||
max_depth: int = 4
|
||||
max_dirs: int = 2000
|
||||
# Additional scope dirs scanned between user and project scopes,
|
||||
# in the order they are provided. Use ``ExtraScope`` to tag each
|
||||
# with its logical label (``queen_ui`` / ``colony_ui``).
|
||||
extra_scopes: list[ExtraScope] = field(default_factory=list)
|
||||
|
||||
|
||||
class SkillDiscovery:
|
||||
@@ -82,13 +110,22 @@ class SkillDiscovery:
|
||||
all_skills: list[ParsedSkill] = []
|
||||
self._scanned_dirs = []
|
||||
|
||||
# Framework scope (lowest precedence)
|
||||
# Framework scope (lowest precedence) — always-on infra skills.
|
||||
if not self._config.skip_framework_scope:
|
||||
framework_dir = Path(__file__).parent / "_default_skills"
|
||||
if framework_dir.is_dir():
|
||||
self._scanned_dirs.append(framework_dir)
|
||||
all_skills.extend(self._scan_scope(framework_dir, "framework"))
|
||||
|
||||
# Preset scope — bundled capability packs that ship with the
|
||||
# framework but default to OFF. User opts in per queen/colony
|
||||
# via the Skills Library. ``skip_framework_scope`` covers both
|
||||
# bundled directories since they live side-by-side on disk.
|
||||
preset_dir = Path(__file__).parent / "_preset_skills"
|
||||
if preset_dir.is_dir():
|
||||
self._scanned_dirs.append(preset_dir)
|
||||
all_skills.extend(self._scan_scope(preset_dir, "preset"))
|
||||
|
||||
# User scope
|
||||
if not self._config.skip_user_scope:
|
||||
home = Path.home()
|
||||
@@ -105,6 +142,13 @@ class SkillDiscovery:
|
||||
self._scanned_dirs.append(user_hive)
|
||||
all_skills.extend(self._scan_scope(user_hive, "user"))
|
||||
|
||||
# Extra scopes (queen_ui / colony_ui), scanned between user and project
|
||||
# so colony overrides beat queen overrides, and both beat user-scope.
|
||||
for extra in self._config.extra_scopes:
|
||||
if extra.directory.is_dir():
|
||||
self._scanned_dirs.append(extra.directory)
|
||||
all_skills.extend(self._scan_scope(extra.directory, extra.label))
|
||||
|
||||
# Project scope (highest precedence)
|
||||
if self._config.project_root:
|
||||
root = self._config.project_root
|
||||
|
||||
@@ -23,6 +23,7 @@ Typical usage — **bare** (exported agents, SDK users)::
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -44,6 +45,18 @@ class SkillsManagerConfig:
|
||||
even when ``project_root`` is set.
|
||||
interactive: Whether trust gating can prompt the user interactively.
|
||||
When ``False``, untrusted project skills are silently skipped.
|
||||
queen_id: Optional queen identifier. When set, enables the
|
||||
``queen_ui`` scope and per-queen override file.
|
||||
queen_overrides_path: Path to
|
||||
``~/.hive/agents/queens/{queen_id}/skills_overrides.json``.
|
||||
When set, the store is loaded and its entries override
|
||||
discovery results (disable skills, record provenance).
|
||||
colony_name: Optional colony identifier; mirrors ``queen_id`` for
|
||||
the ``colony_ui`` scope.
|
||||
colony_overrides_path: Per-colony override file path.
|
||||
extra_scope_dirs: Extra scope dirs scanned between user and
|
||||
project scopes. Typically populated by the caller with the
|
||||
queen/colony UI skill directories.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
@@ -51,6 +64,15 @@ class SkillsManagerConfig:
|
||||
skip_community_discovery: bool = False
|
||||
interactive: bool = True
|
||||
|
||||
# Override support
|
||||
queen_id: str | None = None
|
||||
queen_overrides_path: Path | None = None
|
||||
colony_name: str | None = None
|
||||
colony_overrides_path: Path | None = None
|
||||
# Typed at the call site as ``list[ExtraScope]`` — not imported here
|
||||
# to keep this module free of discovery-layer dependencies.
|
||||
extra_scope_dirs: list = field(default_factory=list)
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Unified skill lifecycle: discovery → loading → prompt renderation.
|
||||
@@ -65,13 +87,21 @@ class SkillsManager:
|
||||
self._config = config or SkillsManagerConfig()
|
||||
self._loaded = False
|
||||
self._catalog: object = None # SkillCatalog, set after load()
|
||||
self._all_skills: list = [] # list[ParsedSkill], pre-override-filter
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
self._default_mgr: object = None # DefaultSkillManager, set after load()
|
||||
# Override stores (loaded lazily in _do_load). Queen-scope and
|
||||
# colony-scope are read together; colony entries win on collision.
|
||||
self._queen_overrides: object = None # SkillOverrideStore | None
|
||||
self._colony_overrides: object = None # SkillOverrideStore | None
|
||||
# Hot-reload state
|
||||
self._watched_dirs: list[str] = []
|
||||
self._watched_files: list[str] = []
|
||||
self._watcher_task: object = None # asyncio.Task, set by start_watching()
|
||||
# Serializes in-process mutations (HTTP handlers + create_colony).
|
||||
self._mutation_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
@@ -119,6 +149,7 @@ class SkillsManager:
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.overrides import SkillOverrideStore
|
||||
|
||||
skills_config = self._config.skills_config
|
||||
|
||||
@@ -128,12 +159,13 @@ class SkillsManager:
|
||||
DiscoveryConfig(
|
||||
project_root=self._config.project_root,
|
||||
skip_framework_scope=False,
|
||||
extra_scopes=list(self._config.extra_scope_dirs or []),
|
||||
)
|
||||
)
|
||||
discovered = discovery.discover()
|
||||
self._watched_dirs = discovery.scanned_directories
|
||||
|
||||
# Trust-gate project-scope skills (AS-13)
|
||||
# Trust-gate project-scope skills (AS-13). UI scopes bypass.
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
from framework.skills.trust import TrustGate
|
||||
|
||||
@@ -141,6 +173,31 @@ class SkillsManager:
|
||||
discovered, project_dir=self._config.project_root
|
||||
)
|
||||
|
||||
# 1b. Load per-scope override stores. Missing files → empty stores.
|
||||
queen_store = None
|
||||
if self._config.queen_overrides_path is not None:
|
||||
queen_store = SkillOverrideStore.load(
|
||||
self._config.queen_overrides_path,
|
||||
scope_label=f"queen:{self._config.queen_id or ''}",
|
||||
)
|
||||
colony_store = None
|
||||
if self._config.colony_overrides_path is not None:
|
||||
colony_store = SkillOverrideStore.load(
|
||||
self._config.colony_overrides_path,
|
||||
scope_label=f"colony:{self._config.colony_name or ''}",
|
||||
)
|
||||
self._queen_overrides = queen_store
|
||||
self._colony_overrides = colony_store
|
||||
self._watched_files = [
|
||||
str(p) for p in (self._config.queen_overrides_path, self._config.colony_overrides_path) if p is not None
|
||||
]
|
||||
|
||||
# 1c. Apply override filtering. Colony entries take precedence over
|
||||
# queen entries on name collision; the store's ``is_disabled`` keeps
|
||||
# the resolution rule in one place.
|
||||
self._all_skills = list(discovered)
|
||||
discovered = self._apply_overrides(discovered, skills_config, queen_store, colony_store)
|
||||
|
||||
catalog = SkillCatalog(discovered)
|
||||
self._catalog = catalog
|
||||
self._allowlisted_dirs = catalog.allowlisted_dirs
|
||||
@@ -174,6 +231,101 @@ class SkillsManager:
|
||||
len(catalog_prompt),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override application
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    @staticmethod
    def _apply_overrides(
        discovered: list,
        skills_config: SkillsConfig,
        queen_store: object,
        colony_store: object,
    ) -> list:
        """Filter ``discovered`` per the queen + colony override stores.

        Resolution rule:
        1. Tombstoned names (``deleted_ui_skills``) drop out.
        2. An explicit ``enabled=False`` override drops the skill.
        3. An explicit ``enabled=True`` override keeps it (wins over
           ``all_defaults_disabled`` for framework defaults AND over the
           preset-scope default-off rule).
        4. Otherwise: preset-scope skills are off by default; everything
           else inherits :meth:`SkillsConfig.is_default_enabled`.
        """
        # Imported locally to keep the module import-light; mirrors the
        # lazy-import style used elsewhere in this manager.
        from framework.skills.overrides import SkillOverrideStore

        stores: list[SkillOverrideStore] = [s for s in (queen_store, colony_store) if s is not None]

        # Union of tombstones across both scopes: a deletion recorded in
        # either scope hides the skill entirely.
        tombstones: set[str] = set()
        for store in stores:
            tombstones |= set(store.deleted_ui_skills)

        out = []
        for skill in discovered:
            if skill.name in tombstones:
                continue
            # Check colony first so colony overrides win over queen's.
            explicit: bool | None = None
            master_disabled = False
            for store in reversed(stores):  # colony, then queen
                entry = store.get(skill.name)
                if entry is not None and entry.enabled is not None:
                    # First explicit entry wins; stop before reading the
                    # master switch of this or any lower-priority store.
                    explicit = entry.enabled
                    break
                if store.all_defaults_disabled:
                    master_disabled = True
            if explicit is False:
                continue
            if explicit is True:
                out.append(skill)
                continue
            # Preset-scope capability packs are bundled but ship OFF; the
            # user must explicitly enable them per queen or colony. This
            # runs even when no store is present so bare agents don't
            # silently load x-automation etc.
            if skill.source_scope == "preset":
                continue
            # No explicit entry — master switch takes effect against framework defaults.
            default_enabled = skills_config.is_default_enabled(skill.name)
            if master_disabled and default_enabled and skill.source_scope == "framework":
                continue
            if default_enabled:
                out.append(skill)
        return out
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    @property
    def queen_overrides(self) -> object:
        """The queen-scope :class:`SkillOverrideStore` or ``None``.

        Set during loading when ``queen_overrides_path`` is configured;
        ``None`` otherwise (and before loading has run).
        """
        return self._queen_overrides
|
||||
|
||||
    @property
    def colony_overrides(self) -> object:
        """The colony-scope :class:`SkillOverrideStore` or ``None``.

        Set during loading when ``colony_overrides_path`` is configured;
        ``None`` otherwise (and before loading has run).
        """
        return self._colony_overrides
|
||||
|
||||
    @property
    def mutation_lock(self) -> asyncio.Lock:
        """Serializes in-process override mutations (routes + queen tools).

        Hold this lock around any read-modify-write of the override
        stores so concurrent mutators don't clobber each other's saves.
        """
        return self._mutation_lock
|
||||
|
||||
    def reload(self) -> None:
        """Re-run discovery and rebuild cached prompts.

        Public wrapper for ``_reload`` so external callers don't touch
        the underscored internal directly.
        """
        self._reload()
|
||||
|
||||
    def enumerate_skills_with_source(self) -> list:
        """Return every discovered skill, including ones disabled by overrides.

        The UI relies on this: a disabled framework skill needs to render
        in the list so the user can toggle it back on. The post-filter
        catalog omits those entries.

        Returns a shallow copy, so mutating the returned list does not
        affect the manager's internal state.
        """
        return list(self._all_skills)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hot-reload: watch skill directories for SKILL.md changes.
|
||||
# ------------------------------------------------------------------
|
||||
@@ -181,14 +333,14 @@ class SkillsManager:
|
||||
async def start_watching(self) -> None:
|
||||
"""Start a background task watching skill directories for changes.
|
||||
|
||||
When a ``SKILL.md`` file is added/modified/removed, the cached
|
||||
``skills_catalog_prompt`` is rebuilt. The next node iteration picks
|
||||
up the new prompt automatically via the ``dynamic_prompt_provider``.
|
||||
Triggers a reload when any ``SKILL.md`` changes or an override
|
||||
JSON file is modified. The next node iteration picks up the new
|
||||
prompt via the ``dynamic_prompt_provider`` / per-worker
|
||||
``dynamic_skills_catalog_provider``.
|
||||
|
||||
Silently no-ops when ``watchfiles`` is not installed or when no
|
||||
directories are being watched (e.g. bare mode, no project_root).
|
||||
Silently no-ops when ``watchfiles`` is not installed or there
|
||||
are no paths to watch.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
import watchfiles # noqa: F401 -- optional dep check
|
||||
@@ -196,7 +348,7 @@ class SkillsManager:
|
||||
logger.debug("watchfiles not installed; skill hot-reload disabled")
|
||||
return
|
||||
|
||||
if not self._watched_dirs:
|
||||
if not self._watched_dirs and not self._watched_files:
|
||||
logger.debug("No skill directories to watch; hot-reload skipped")
|
||||
return
|
||||
|
||||
@@ -208,14 +360,13 @@ class SkillsManager:
|
||||
name="skills-hot-reload",
|
||||
)
|
||||
logger.info(
|
||||
"Skill hot-reload enabled (watching %d directories)",
|
||||
"Skill hot-reload enabled (watching %d dirs, %d override files)",
|
||||
len(self._watched_dirs),
|
||||
len(self._watched_files),
|
||||
)
|
||||
|
||||
async def stop_watching(self) -> None:
|
||||
"""Cancel the background watcher task (if running)."""
|
||||
import asyncio
|
||||
|
||||
task = self._watcher_task
|
||||
if task is None:
|
||||
return
|
||||
@@ -228,22 +379,35 @@ class SkillsManager:
|
||||
pass
|
||||
|
||||
async def _watch_loop(self) -> None:
|
||||
"""Background coroutine that watches SKILL.md files and triggers reload."""
|
||||
import asyncio
|
||||
|
||||
"""Watch SKILL.md + override JSON files and trigger reload on change."""
|
||||
import watchfiles
|
||||
|
||||
def _filter(_change: object, path: str) -> bool:
|
||||
return path.endswith("SKILL.md")
|
||||
return path.endswith("SKILL.md") or path.endswith("skills_overrides.json")
|
||||
|
||||
# watchfiles accepts a mix of dirs and files; file watches survive
|
||||
# a tmp+rename (the containing dir sees the event).
|
||||
watch_targets = list(self._watched_dirs)
|
||||
for f in self._watched_files:
|
||||
# watchfiles needs the parent dir for file-level events to fire
|
||||
# reliably through atomic replace; adding the file path directly
|
||||
# works on Linux/macOS inotify/FSEvents but a dir watch is
|
||||
# belt-and-braces.
|
||||
parent = str(Path(f).parent)
|
||||
if parent not in watch_targets:
|
||||
watch_targets.append(parent)
|
||||
|
||||
if not watch_targets:
|
||||
return
|
||||
|
||||
try:
|
||||
async for changes in watchfiles.awatch(
|
||||
*self._watched_dirs,
|
||||
*watch_targets,
|
||||
watch_filter=_filter,
|
||||
debounce=1000,
|
||||
):
|
||||
paths = [p for _, p in changes]
|
||||
logger.info("SKILL.md changes detected: %s", paths)
|
||||
logger.info("Skill state changes detected: %s", paths)
|
||||
try:
|
||||
self._reload()
|
||||
except Exception:
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Per-scope skill override store.
|
||||
|
||||
Sits between :mod:`framework.skills.discovery` and
|
||||
:class:`framework.skills.catalog.SkillCatalog`: records the user's
|
||||
per-queen and per-colony decisions about which skills are enabled,
|
||||
who created them (provenance), and any parameter tweaks.
|
||||
|
||||
Two well-known paths back this module:
|
||||
|
||||
* Queen scope: ``~/.hive/agents/queens/{queen_id}/skills_overrides.json``
|
||||
* Colony scope: ``~/.hive/colonies/{colony_name}/skills_overrides.json``
|
||||
|
||||
The schema is intentionally small; see :class:`SkillOverrideStore` for
|
||||
the JSON shape. Atomic writes mirror
|
||||
:class:`framework.skills.trust.TrustedRepoStore` (tmp + rename).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
class Provenance(StrEnum):
|
||||
"""Where a skill came from.
|
||||
|
||||
The override store is the authoritative provenance ledger for anything
|
||||
the UI or the queen tools touched. Framework / user-dropped /
|
||||
project-dropped skills don't need an entry unless they've been
|
||||
explicitly configured.
|
||||
"""
|
||||
|
||||
FRAMEWORK = "framework"
|
||||
PRESET = "preset"
|
||||
USER_DROPPED = "user_dropped"
|
||||
USER_UI_CREATED = "user_ui_created"
|
||||
QUEEN_CREATED = "queen_created"
|
||||
LEARNED_RUNTIME = "learned_runtime"
|
||||
PROJECT_DROPPED = "project_dropped"
|
||||
# Catch-all for skills with no recorded authorship: legacy rows from
|
||||
# before the override store existed, PATCHes that precede any CREATE,
|
||||
# etc. Keeps the ledger honest rather than forcing a guess.
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@dataclass
class OverrideEntry:
    """Per-skill override record inside a scope's store.

    Every field is optional in the serialized form; :meth:`from_dict`
    is deliberately lenient so a hand-edited or partially-corrupted
    entry degrades to defaults instead of crashing skill loading.
    """

    # Tri-state: True/False force the skill on/off; None inherits defaults.
    enabled: bool | None = None
    provenance: Provenance = Provenance.FRAMEWORK
    trust: str | None = None
    param_overrides: dict[str, Any] = field(default_factory=dict)
    notes: str | None = None
    created_at: datetime | None = None
    created_by: str | None = None

    def clone(self) -> OverrideEntry:
        """Return a deep-enough copy (dict fields are re-allocated)."""
        return OverrideEntry(
            enabled=self.enabled,
            provenance=self.provenance,
            trust=self.trust,
            param_overrides=dict(self.param_overrides),
            notes=self.notes,
            created_at=self.created_at,
            created_by=self.created_by,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the JSON entry shape; unset optional fields are omitted."""
        out: dict[str, Any] = {"provenance": str(self.provenance)}
        if self.enabled is not None:
            out["enabled"] = bool(self.enabled)
        if self.trust is not None:
            out["trust"] = self.trust
        if self.param_overrides:
            out["param_overrides"] = dict(self.param_overrides)
        if self.notes is not None:
            out["notes"] = self.notes
        if self.created_at is not None:
            out["created_at"] = self.created_at.isoformat()
        if self.created_by is not None:
            out["created_by"] = self.created_by
        return out

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> OverrideEntry:
        """Parse a raw JSON object; malformed fields fall back to defaults.

        Never raises on bad field values: an unknown provenance logs a
        warning and defaults, a bad timestamp or non-dict
        ``param_overrides`` is dropped, and wrongly-typed scalars
        become ``None``.
        """
        created_at: datetime | None = None
        created_at_raw = raw.get("created_at")
        if isinstance(created_at_raw, str):
            try:
                created_at = datetime.fromisoformat(created_at_raw)
            except ValueError:
                created_at = None

        provenance_raw = raw.get("provenance") or Provenance.FRAMEWORK
        try:
            provenance = Provenance(provenance_raw)
        # BUG FIX: also catch TypeError — an unhashable JSON value (list,
        # dict) makes the enum value lookup raise TypeError, not
        # ValueError, which previously escaped and crashed loading.
        except (TypeError, ValueError):
            logger.warning("override: unknown provenance %r; defaulting to framework", provenance_raw)
            provenance = Provenance.FRAMEWORK

        # BUG FIX: the old code did dict(raw.get("param_overrides") or {}),
        # which raises TypeError when the stored value is a list/str/number
        # — crashing the store load despite its "permissive on parse
        # errors" contract. Only accept an actual JSON object.
        params_raw = raw.get("param_overrides")
        param_overrides = dict(params_raw) if isinstance(params_raw, dict) else {}

        # Bind each lookup once; wrongly-typed scalars degrade to None.
        enabled = raw.get("enabled")
        trust = raw.get("trust")
        notes = raw.get("notes")
        created_by = raw.get("created_by")
        return cls(
            enabled=enabled if isinstance(enabled, bool) else None,
            provenance=provenance,
            trust=trust if isinstance(trust, str) else None,
            param_overrides=param_overrides,
            notes=notes if isinstance(notes, str) else None,
            created_at=created_at,
            created_by=created_by if isinstance(created_by, str) else None,
        )
|
||||
|
||||
|
||||
@dataclass
class SkillOverrideStore:
    """Persistent per-scope override file.

    The file is created lazily on first save; a missing file behaves like
    an empty store (all skills inherit defaults, no metadata recorded).
    """

    # Location of this scope's skills_overrides.json.
    path: Path
    # Human-readable scope tag (e.g. "queen:<id>" / "colony:<name>") for logs.
    scope_label: str = ""
    version: int = _SCHEMA_VERSION
    # Master switch: with it set, default-enabled skills are treated as
    # disabled unless an entry explicitly re-enables them.
    all_defaults_disabled: bool = False
    overrides: dict[str, OverrideEntry] = field(default_factory=dict)
    # Tombstones for UI-created skills deleted via the UI (see remove()).
    deleted_ui_skills: set[str] = field(default_factory=set)

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @classmethod
    def load(cls, path: Path, scope_label: str = "") -> SkillOverrideStore:
        """Load the store from disk; return an empty store if the file is absent.

        Permissive on parse errors: logs and returns an empty store rather
        than raising, so a corrupted file never takes down skill loading.
        Individually malformed fields or entries are skipped with a warning.
        """
        store = cls(path=path, scope_label=scope_label)
        try:
            raw = json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return store
        except Exception as exc:
            logger.warning("override: failed to read %s (%s); starting empty", path, exc)
            return store
        if not isinstance(raw, dict):
            logger.warning("override: %s is not an object; starting empty", path)
            return store

        # BUG FIX: int(raw.get("version", ...)) previously ran unguarded
        # after the try block, so a corrupted value (e.g. "version": "x"
        # or a list) raised ValueError/TypeError and crashed skill
        # loading — violating the permissive contract above.
        try:
            store.version = int(raw.get("version", _SCHEMA_VERSION))
        except (TypeError, ValueError):
            logger.warning("override: %s has a non-integer version; assuming %d", path, _SCHEMA_VERSION)
            store.version = _SCHEMA_VERSION
        store.all_defaults_disabled = bool(raw.get("all_defaults_disabled", False))

        raw_overrides = raw.get("overrides") or {}
        if isinstance(raw_overrides, dict):
            for name, entry_raw in raw_overrides.items():
                if not isinstance(name, str) or not isinstance(entry_raw, dict):
                    continue
                # One malformed entry must not poison the rest of the store.
                try:
                    store.overrides[name] = OverrideEntry.from_dict(entry_raw)
                except Exception as exc:
                    logger.warning("override: skipping malformed entry %r in %s (%s)", name, path, exc)

        deleted = raw.get("deleted_ui_skills") or []
        if isinstance(deleted, list):
            store.deleted_ui_skills = {s for s in deleted if isinstance(s, str)}
        return store

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------

    def upsert(self, skill_name: str, entry: OverrideEntry) -> None:
        """Insert or replace a skill's override entry."""
        self.overrides[skill_name] = entry
        # If we're explicitly managing this skill again, lift any tombstone.
        self.deleted_ui_skills.discard(skill_name)

    def set_enabled(self, skill_name: str, enabled: bool, *, provenance: Provenance | None = None) -> None:
        """Convenience: toggle enabled without rewriting other fields.

        NOTE(review): unlike :meth:`upsert`, this does not lift an
        existing tombstone — confirm that's intended for re-enables.
        """
        existing = self.overrides.get(skill_name)
        if existing is None:
            existing = OverrideEntry(
                enabled=enabled,
                provenance=provenance or Provenance.FRAMEWORK,
            )
        else:
            existing.enabled = enabled
            if provenance is not None:
                existing.provenance = provenance
        self.overrides[skill_name] = existing

    def remove(self, skill_name: str, *, tombstone: bool = True) -> None:
        """Drop a skill's override entry; optionally leave a tombstone.

        Tombstones matter for UI-created skills: if the user deletes a
        queen-scope skill via the UI, we rm-tree its directory, but the
        file watcher might lag or a background process might have an
        open handle. A tombstone ensures the loader treats the skill as
        gone even if a stale SKILL.md lingers.
        """
        self.overrides.pop(skill_name, None)
        if tombstone:
            self.deleted_ui_skills.add(skill_name)

    def is_disabled(self, skill_name: str, *, default_enabled: bool) -> bool:
        """Return True when this scope's override force-disables the skill."""
        if self.all_defaults_disabled and default_enabled:
            # Caller says "default enabled"; master switch flips it off unless
            # an explicit enabled=True override re-enables.
            entry = self.overrides.get(skill_name)
            if entry is not None and entry.enabled is True:
                return False
            return True
        entry = self.overrides.get(skill_name)
        if entry is None:
            return not default_enabled
        if entry.enabled is None:
            return not default_enabled
        return not entry.enabled

    def effective_enabled(self, skill_name: str, *, default_enabled: bool) -> bool:
        """The inverse of :meth:`is_disabled`, for readability at call sites."""
        return not self.is_disabled(skill_name, default_enabled=default_enabled)

    def get(self, skill_name: str) -> OverrideEntry | None:
        """Return the skill's override entry, or ``None`` when unmanaged."""
        return self.overrides.get(skill_name)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self) -> None:
        """Atomic write: tmp + rename. Creates the parent dir if needed."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        payload: dict[str, Any] = {
            "version": self.version,
            "all_defaults_disabled": self.all_defaults_disabled,
            # Sorted for stable diffs of the on-disk file.
            "overrides": {name: entry.to_dict() for name, entry in sorted(self.overrides.items())},
        }
        if self.deleted_ui_skills:
            payload["deleted_ui_skills"] = sorted(self.deleted_ui_skills)
        tmp = self.path.with_suffix(self.path.suffix + ".tmp")
        tmp.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        tmp.replace(self.path)
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Single source of truth for override timestamps."""
|
||||
return datetime.now(tz=UTC)
|
||||
@@ -20,9 +20,17 @@ from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
# Bundled skills live in two sibling dirs: ``_default_skills`` (always-on
|
||||
# infra) and ``_preset_skills`` (capability packs, off by default but
|
||||
# still bundled). Tool-gated pre-activation walks both so ``browser_*``
|
||||
# tools still pull in the browser-automation preset even though it isn't
|
||||
# default-enabled in the catalog.
|
||||
_BUNDLED_DIRS: tuple[Path, ...] = (
|
||||
Path(__file__).parent / "_default_skills",
|
||||
Path(__file__).parent / "_preset_skills",
|
||||
)
|
||||
|
||||
# (tool-name prefix, default skill directory name, display name)
|
||||
# (tool-name prefix, skill directory name, display name)
|
||||
_TOOL_GATED_SKILLS: list[tuple[str, str, str]] = [
|
||||
("browser_", "browser-automation", "hive.browser-automation"),
|
||||
]
|
||||
@@ -31,12 +39,23 @@ _BODY_CACHE: dict[str, str] = {}
|
||||
|
||||
|
||||
def _load_body(dir_name: str) -> str:
|
||||
"""Load the markdown body of a framework default skill, cached."""
|
||||
"""Load the markdown body of a bundled skill, cached. Searches every
|
||||
bundled directory (default + preset) so the mapping table doesn't
|
||||
need to know which dir a skill lives in.
|
||||
"""
|
||||
if dir_name in _BODY_CACHE:
|
||||
return _BODY_CACHE[dir_name]
|
||||
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
path: Path | None = None
|
||||
for parent in _BUNDLED_DIRS:
|
||||
candidate = parent / dir_name / "SKILL.md"
|
||||
if candidate.exists():
|
||||
path = candidate
|
||||
break
|
||||
body = ""
|
||||
if path is None:
|
||||
_BODY_CACHE[dir_name] = body
|
||||
return body
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
# Strip YAML frontmatter (between the first two '---' fences)
|
||||
|
||||
@@ -318,13 +318,19 @@ class TrustGate:
|
||||
) -> list[ParsedSkill]:
|
||||
"""Return the subset of skills that are trusted for loading.
|
||||
|
||||
- Framework and user-scope skills: always included.
|
||||
- Framework, user, queen_ui, and colony_ui scopes: always included.
|
||||
(UI-created skills are authenticated by the user creating them
|
||||
through the authenticated UI — they do not go through the
|
||||
trusted_repos.json flow.)
|
||||
- Project-scope skills: classified; consent prompt shown if untrusted.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Separate project skills from always-trusted scopes
|
||||
always_trusted = [s for s in skills if s.source_scope != "project"]
|
||||
# UI-authored scopes bypass the trust gate — they're implicitly
|
||||
# trusted because the user authored them through the UI. ``preset``
|
||||
# ships with the framework distribution, so it's trusted too.
|
||||
_bypass_scopes = {"framework", "preset", "user", "queen_ui", "colony_ui"}
|
||||
always_trusted = [s for s in skills if s.source_scope in _bypass_scopes]
|
||||
project_skills = [s for s in skills if s.source_scope == "project"]
|
||||
|
||||
if not project_skills:
|
||||
|
||||
@@ -116,6 +116,9 @@ class WorkerSessionAdapter:
|
||||
worker_path: Path | None = None
|
||||
|
||||
|
||||
QUEEN_PHASES: frozenset[str] = frozenset({"independent", "incubating", "working", "reviewing"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueenPhaseState:
|
||||
"""Mutable state container for queen operating phase.
|
||||
@@ -131,7 +134,7 @@ class QueenPhaseState:
|
||||
that trigger phase transitions.
|
||||
"""
|
||||
|
||||
phase: str = "independent" # "independent", "incubating", "working", or "reviewing"
|
||||
phase: str = "independent" # one of QUEEN_PHASES
|
||||
independent_tools: list = field(default_factory=list) # list[Tool]
|
||||
incubating_tools: list = field(default_factory=list) # list[Tool]
|
||||
working_tools: list = field(default_factory=list) # list[Tool]
|
||||
@@ -182,10 +185,31 @@ class QueenPhaseState:
|
||||
# Cached recall blocks — populated async by recall_selector after each turn.
|
||||
_cached_global_recall_block: str = ""
|
||||
_cached_queen_recall_block: str = ""
|
||||
# Cached dynamic system-prompt suffix — frozen at user-turn boundaries so
|
||||
# AgentLoop iterations within a single turn send a byte-stable prompt and
|
||||
# Anthropic's prompt cache keeps the static block warm. Rebuilt by
|
||||
# refresh_dynamic_suffix() on CLIENT_INPUT_RECEIVED and on phase change.
|
||||
_cached_dynamic_suffix: str = ""
|
||||
# Memory directories.
|
||||
global_memory_dir: Path | None = None
|
||||
queen_memory_dir: Path | None = None
|
||||
|
||||
# Per-queen MCP tool allowlist for the INDEPENDENT phase. ``None`` means
|
||||
# "allow every MCP tool" (default, backward-compatible). An explicit list
|
||||
# is authoritative: only tools whose name appears here pass through.
|
||||
# Lifecycle / synthetic tools bypass this gate regardless.
|
||||
enabled_mcp_tools: list[str] | None = None
|
||||
# Union of every MCP-origin tool name currently registered — the set the
|
||||
# allowlist can gate. Populated once at queen boot from
|
||||
# ``ToolRegistry._mcp_server_tools``. Names outside this set (lifecycle,
|
||||
# ``ask_user``) always pass through the filter.
|
||||
mcp_tool_names_all: set = field(default_factory=set)
|
||||
# Memoized output of the filter applied to ``independent_tools``.
|
||||
# Recomputed only when ``enabled_mcp_tools`` or ``independent_tools``
|
||||
# changes, so ``get_current_tools()`` in the independent phase returns
|
||||
# a byte-stable list between saves and the LLM prompt cache stays warm.
|
||||
_filtered_independent_tools: list = field(default_factory=list)
|
||||
|
||||
async def switch_to_working(self, source: str = "tool") -> None:
|
||||
"""Switch to working phase — colony workers are running.
|
||||
|
||||
@@ -204,6 +228,44 @@ class QueenPhaseState:
|
||||
"Colony workers are running. Available tools: " + ", ".join(tool_names) + "."
|
||||
)
|
||||
|
||||
def rebuild_independent_filter(self) -> None:
|
||||
"""Recompute the memoized independent-phase tool list.
|
||||
|
||||
Called once at queen boot (after ``independent_tools``,
|
||||
``mcp_tool_names_all`` and ``enabled_mcp_tools`` are all populated)
|
||||
and again from the tools-PATCH handler whenever the allowlist
|
||||
changes. Keeping the result memoized means the independent-phase
|
||||
branch of ``get_current_tools()`` returns the same Python list
|
||||
object across turns, so the LLM prompt cache stays warm until
|
||||
the user explicitly edits their allowlist.
|
||||
"""
|
||||
if self.enabled_mcp_tools is None:
|
||||
self._filtered_independent_tools = list(self.independent_tools)
|
||||
return
|
||||
allowed = set(self.enabled_mcp_tools)
|
||||
# If ``mcp_tool_names_all`` is empty, every tool falls through the
|
||||
# "not in mcp_tool_names_all" branch below and the allowlist is
|
||||
# silently ignored. That's a fail-open bug (the symptom: a
|
||||
# role-restricted queen sees every MCP tool). Log a warning so the
|
||||
# upstream cause is visible next time it happens.
|
||||
if not self.mcp_tool_names_all:
|
||||
logger.warning(
|
||||
"rebuild_independent_filter: mcp_tool_names_all is empty but "
|
||||
"allowlist has %d entries — allowlist cannot be applied. "
|
||||
"Check that queen boot populated phase_state.mcp_tool_names_all.",
|
||||
len(allowed),
|
||||
)
|
||||
self._filtered_independent_tools = [
|
||||
t for t in self.independent_tools if t.name not in self.mcp_tool_names_all or t.name in allowed
|
||||
]
|
||||
logger.info(
|
||||
"rebuild_independent_filter: allowlist=%d, mcp_names=%d, independent=%d -> filtered=%d",
|
||||
len(allowed),
|
||||
len(self.mcp_tool_names_all),
|
||||
len(self.independent_tools),
|
||||
len(self._filtered_independent_tools),
|
||||
)
|
||||
|
||||
def get_current_tools(self) -> list:
|
||||
"""Return tools for the current phase."""
|
||||
if self.phase == "working":
|
||||
@@ -212,11 +274,29 @@ class QueenPhaseState:
|
||||
return list(self.reviewing_tools)
|
||||
if self.phase == "incubating":
|
||||
return list(self.incubating_tools)
|
||||
# Default / "independent" — DM mode with full MCP tools.
|
||||
return list(self.independent_tools)
|
||||
# Default / "independent" — DM mode with full MCP tools, gated by
|
||||
# the per-queen allowlist. Return the memoized list directly so the
|
||||
# JSON sent to the LLM is byte-identical turn-to-turn.
|
||||
if not self._filtered_independent_tools and self.independent_tools:
|
||||
# Safety net: first-call in tests or code paths that skipped
|
||||
# the explicit boot-time rebuild.
|
||||
self.rebuild_independent_filter()
|
||||
return self._filtered_independent_tools
|
||||
|
||||
def get_current_prompt(self) -> str:
|
||||
"""Return the system prompt for the current phase."""
|
||||
def get_static_prompt(self) -> str:
|
||||
"""Return the stable portion of the system prompt for the current phase.
|
||||
|
||||
Includes identity, phase-role prompt, connected-integrations block,
|
||||
skills catalog, and default skill protocols. These change only on
|
||||
phase transition, queen identity selection, or when the user adds/
|
||||
removes an integration — rare events. Designed to be byte-stable
|
||||
across AgentLoop iterations within a single user turn so that
|
||||
Anthropic's prompt cache keeps this block warm.
|
||||
|
||||
The dynamic tail (recall + timestamp) is returned separately by
|
||||
``get_dynamic_suffix()``; the LLM wrapper emits them as two system
|
||||
content blocks with a cache breakpoint between them.
|
||||
"""
|
||||
if self.phase == "working":
|
||||
base = self.prompt_working
|
||||
elif self.phase == "reviewing":
|
||||
@@ -250,11 +330,51 @@ class QueenPhaseState:
|
||||
parts.append(catalog_prompt)
|
||||
if self.protocols_prompt:
|
||||
parts.append(self.protocols_prompt)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def refresh_dynamic_suffix(self) -> str:
|
||||
"""Rebuild and cache the dynamic system-prompt suffix.
|
||||
|
||||
The suffix contains recall blocks only. Called from the
|
||||
CLIENT_INPUT_RECEIVED subscriber so the suffix is byte-stable across
|
||||
every AgentLoop iteration within a single user turn.
|
||||
|
||||
Timestamps used to live here too; they were moved into the
|
||||
conversation itself as a ``[YYYY-MM-DD HH:MM TZ]`` prefix on each
|
||||
injected event (see ``drain_injection_queue``) so they ride on
|
||||
byte-stable conversation history instead of busting the
|
||||
per-turn system-prompt cache tail.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
if self._cached_global_recall_block:
|
||||
parts.append(self._cached_global_recall_block)
|
||||
if self._cached_queen_recall_block:
|
||||
parts.append(self._cached_queen_recall_block)
|
||||
return "\n\n".join(parts)
|
||||
self._cached_dynamic_suffix = "\n\n".join(parts)
|
||||
return self._cached_dynamic_suffix
|
||||
|
||||
def get_dynamic_suffix(self) -> str:
|
||||
"""Return the cached dynamic system-prompt suffix.
|
||||
|
||||
Lazily populates on first call so callers don't have to know about
|
||||
the refresh lifecycle. Subsequent calls return the cached string
|
||||
until ``refresh_dynamic_suffix()`` is invoked again.
|
||||
"""
|
||||
if not self._cached_dynamic_suffix:
|
||||
self.refresh_dynamic_suffix()
|
||||
return self._cached_dynamic_suffix
|
||||
|
||||
def get_current_prompt(self) -> str:
|
||||
"""Return the concatenated system prompt (static + dynamic).
|
||||
|
||||
Retained for backward compatibility and for callers that want one
|
||||
string (conversation persistence, debug dumps). The AgentLoop sends
|
||||
the two pieces separately to the LLM so the cache can break between
|
||||
them — see ``get_static_prompt()`` / ``get_dynamic_suffix()``.
|
||||
"""
|
||||
static = self.get_static_prompt()
|
||||
dynamic = self.get_dynamic_suffix()
|
||||
return f"{static}\n\n{dynamic}" if dynamic else static
|
||||
|
||||
async def _emit_phase_event(self) -> None:
|
||||
"""Publish a QUEEN_PHASE_CHANGED event so the frontend updates the tag."""
|
||||
@@ -1382,124 +1502,8 @@ def register_queen_lifecycle_tools(
|
||||
# re-runs idempotent.
|
||||
|
||||
import re as _re
|
||||
import shutil as _shutil
|
||||
|
||||
_COLONY_NAME_RE = _re.compile(r"^[a-z0-9_]+$")
|
||||
_SKILL_NAME_RE = _re.compile(r"^[a-z0-9-]+$")
|
||||
|
||||
def _materialize_skill_folder(
|
||||
*,
|
||||
skill_name: str,
|
||||
skill_description: str,
|
||||
skill_body: str,
|
||||
skill_files: list[dict] | None,
|
||||
colony_dir: Path,
|
||||
) -> tuple[Path | None, str | None, bool]:
|
||||
"""Write a skill folder under ``{colony_dir}/.hive/skills/{name}/`` from inline content.
|
||||
|
||||
The skill is scoped to a single colony: ``SkillDiscovery`` scans
|
||||
``{project_root}/.hive/skills/`` as project-scope, and the
|
||||
colony's worker uses ``project_root = colony_dir`` — so only
|
||||
that colony's workers see it, not every colony on the machine.
|
||||
We deliberately avoid ``~/.hive/skills/`` here because that
|
||||
directory is scanned as user scope and leaks into every agent.
|
||||
|
||||
Returns ``(installed_path, error, replaced)``. On success
|
||||
``error`` is ``None`` and ``installed_path`` is the final
|
||||
location; ``replaced`` is ``True`` when an existing skill with
|
||||
the same name was overwritten. On failure ``installed_path`` is
|
||||
``None``, ``error`` is a human-readable reason, and
|
||||
``replaced`` is ``False``.
|
||||
"""
|
||||
name = (skill_name or "").strip() if isinstance(skill_name, str) else ""
|
||||
if not name:
|
||||
return None, "skill_name is required", False
|
||||
if not _SKILL_NAME_RE.match(name):
|
||||
return None, (f"skill_name '{name}' must match [a-z0-9-] pattern"), False
|
||||
if name.startswith("-") or name.endswith("-") or "--" in name:
|
||||
return None, (f"skill_name '{name}' has leading/trailing/consecutive hyphens"), False
|
||||
if len(name) > 64:
|
||||
return None, f"skill_name '{name}' exceeds 64 chars", False
|
||||
|
||||
desc = (skill_description or "").strip() if isinstance(skill_description, str) else ""
|
||||
if not desc:
|
||||
return None, "skill_description is required", False
|
||||
if len(desc) > 1024:
|
||||
return None, "skill_description must be 1–1024 chars", False
|
||||
# Frontmatter descriptions must stay on a single line because
|
||||
# our frontmatter parser is line-oriented and the downstream
|
||||
# skill loader expects ``description:`` to resolve to one value.
|
||||
if "\n" in desc or "\r" in desc:
|
||||
return None, "skill_description must be a single line (no newlines)", False
|
||||
|
||||
body = skill_body if isinstance(skill_body, str) else ""
|
||||
if not body.strip():
|
||||
return (
|
||||
None,
|
||||
(
|
||||
"skill_body is required — the operational procedure the "
|
||||
"colony worker needs to run this job unattended"
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
# Optional supporting files (scripts/, references/, assets/…).
|
||||
# Each entry: {"path": "<relative>", "content": "<text>"}.
|
||||
normalized_files: list[tuple[Path, str]] = []
|
||||
if skill_files:
|
||||
if not isinstance(skill_files, list):
|
||||
return None, "skill_files must be an array", False
|
||||
for entry in skill_files:
|
||||
if not isinstance(entry, dict):
|
||||
return None, "each skill_files entry must be an object with 'path' and 'content'", False
|
||||
rel_raw = entry.get("path")
|
||||
content = entry.get("content")
|
||||
if not isinstance(rel_raw, str) or not rel_raw.strip():
|
||||
return None, "skill_files entry missing non-empty 'path'", False
|
||||
if not isinstance(content, str):
|
||||
return None, f"skill_files entry '{rel_raw}' missing string 'content'", False
|
||||
rel_stripped = rel_raw.strip()
|
||||
# Normalize a leading ``./`` but do NOT strip bare ``/`` —
|
||||
# an absolute path should be rejected, not silently relativized.
|
||||
if rel_stripped.startswith("./"):
|
||||
rel_stripped = rel_stripped[2:]
|
||||
rel_path = Path(rel_stripped)
|
||||
if rel_stripped.startswith("/") or rel_path.is_absolute() or ".." in rel_path.parts:
|
||||
return None, (f"skill_files path '{rel_raw}' must be relative and inside the skill folder"), False
|
||||
if rel_path.as_posix() == "SKILL.md":
|
||||
return None, ("skill_files must not contain SKILL.md — pass skill_body instead"), False
|
||||
normalized_files.append((rel_path, content))
|
||||
|
||||
target_root = colony_dir / ".hive" / "skills"
|
||||
target = target_root / name
|
||||
try:
|
||||
target_root.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as e:
|
||||
return None, f"failed to create skills root: {e}", False
|
||||
|
||||
replaced = False
|
||||
try:
|
||||
if target.exists():
|
||||
# Queen is re-creating a skill under the same name —
|
||||
# her latest content wins. rmtree first so stale files
|
||||
# from a prior version don't linger alongside the new
|
||||
# ones (copytree with dirs_exist_ok would merge them).
|
||||
replaced = True
|
||||
_shutil.rmtree(target)
|
||||
target.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
body_norm = body.rstrip() + "\n"
|
||||
skill_md_text = f"---\nname: {name}\ndescription: {desc}\n---\n\n{body_norm}"
|
||||
(target / "SKILL.md").write_text(skill_md_text, encoding="utf-8")
|
||||
|
||||
for rel_path, file_content in normalized_files:
|
||||
full_path = target / rel_path
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
full_path.write_text(file_content, encoding="utf-8")
|
||||
except OSError as e:
|
||||
return None, f"failed to write skill folder {target}: {e}", False
|
||||
|
||||
return target, None, replaced
|
||||
|
||||
def _validate_triggers(raw: Any) -> tuple[list[dict] | None, str | None]:
|
||||
"""Validate and normalize the ``triggers`` argument for create_colony.
|
||||
@@ -1645,17 +1649,26 @@ def register_queen_lifecycle_tools(
|
||||
except OSError as e:
|
||||
return json.dumps({"error": f"failed to create colony dir {colony_dir}: {e}"})
|
||||
|
||||
installed_skill, skill_err, skill_replaced = _materialize_skill_folder(
|
||||
# Validate + write via the shared authoring module so the HTTP
|
||||
# routes and this tool stay in lockstep.
|
||||
from framework.skills.authoring import build_draft, write_skill
|
||||
from framework.skills.overrides import (
|
||||
OverrideEntry,
|
||||
Provenance,
|
||||
SkillOverrideStore,
|
||||
utc_now,
|
||||
)
|
||||
|
||||
draft, draft_err = build_draft(
|
||||
skill_name=skill_name,
|
||||
skill_description=skill_description,
|
||||
skill_body=skill_body,
|
||||
skill_files=skill_files,
|
||||
colony_dir=colony_dir,
|
||||
)
|
||||
if skill_err is not None:
|
||||
if draft_err is not None or draft is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": skill_err,
|
||||
"error": draft_err or "invalid skill draft",
|
||||
"hint": (
|
||||
"Provide skill_name (lowercase [a-z0-9-], ≤64 chars), "
|
||||
"skill_description (single line, 1–1024 chars), and "
|
||||
@@ -1668,6 +1681,63 @@ def register_queen_lifecycle_tools(
|
||||
}
|
||||
)
|
||||
|
||||
installed_skill, write_err, skill_replaced = write_skill(
|
||||
draft,
|
||||
target_root=colony_dir / ".hive" / "skills",
|
||||
replace_existing=True,
|
||||
)
|
||||
if write_err is not None or installed_skill is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": write_err or "failed to write skill folder",
|
||||
}
|
||||
)
|
||||
|
||||
# Seed the colony's override ledger from the queen's current
|
||||
# state so the colony inherits everything she had enabled (preset
|
||||
# capability packs, toggled-off framework defaults, etc.) at fork
|
||||
# time. The colony then owns its own copy — later queen edits
|
||||
# don't retroactively alter this colony's skill surface.
|
||||
# On top of the seed we upsert the newly-written skill with
|
||||
# QUEEN_CREATED provenance so the UI renders + edits it properly.
|
||||
try:
|
||||
from framework.config import QUEENS_DIR
|
||||
|
||||
overrides_path = colony_dir / "skills_overrides.json"
|
||||
queen_id = getattr(session, "queen_name", None) or "unknown"
|
||||
colony_store = SkillOverrideStore.load(overrides_path, scope_label=f"colony:{cn}")
|
||||
|
||||
queen_overrides_path = QUEENS_DIR / queen_id / "skills_overrides.json"
|
||||
if queen_overrides_path.exists():
|
||||
queen_store = SkillOverrideStore.load(queen_overrides_path, scope_label=f"queen:{queen_id}")
|
||||
# Shallow clone: queen's explicit toggles + master switch
|
||||
# become the colony's starting state. Tombstones propagate
|
||||
# so a queen-deleted UI skill doesn't resurrect here.
|
||||
colony_store.all_defaults_disabled = queen_store.all_defaults_disabled
|
||||
for sname, entry in queen_store.overrides.items():
|
||||
# Don't overwrite an entry the colony already set
|
||||
# (rare on fresh fork; matters if this is a re-fork).
|
||||
if sname in colony_store.overrides:
|
||||
continue
|
||||
colony_store.upsert(sname, entry.clone())
|
||||
for sname in queen_store.deleted_ui_skills:
|
||||
colony_store.deleted_ui_skills.add(sname)
|
||||
|
||||
colony_store.upsert(
|
||||
draft.name,
|
||||
OverrideEntry(
|
||||
enabled=True,
|
||||
provenance=Provenance.QUEEN_CREATED,
|
||||
created_at=utc_now(),
|
||||
created_by=f"queen:{queen_id}",
|
||||
),
|
||||
)
|
||||
colony_store.save()
|
||||
except Exception:
|
||||
# Registration is best-effort; discovery still surfaces the
|
||||
# skill as project-scope even if the ledger fails to update.
|
||||
logger.warning("create_colony: override registration failed", exc_info=True)
|
||||
|
||||
logger.info(
|
||||
"create_colony: materialized skill at %s (replaced=%s)",
|
||||
installed_skill,
|
||||
|
||||
@@ -5,6 +5,8 @@ import ColonyChat from "./pages/colony-chat";
|
||||
import QueenDM from "./pages/queen-dm";
|
||||
import OrgChart from "./pages/org-chart";
|
||||
import PromptLibrary from "./pages/prompt-library";
|
||||
import SkillsLibrary from "./pages/skills-library";
|
||||
import ToolLibrary from "./pages/tool-library";
|
||||
import CredentialsPage from "./pages/credentials";
|
||||
import NotFound from "./pages/not-found";
|
||||
|
||||
@@ -16,7 +18,9 @@ function App() {
|
||||
<Route path="/colony/:colonyId" element={<ColonyChat />} />
|
||||
<Route path="/queen/:queenId" element={<QueenDM />} />
|
||||
<Route path="/org-chart" element={<OrgChart />} />
|
||||
<Route path="/skills-library" element={<SkillsLibrary />} />
|
||||
<Route path="/prompt-library" element={<PromptLibrary />} />
|
||||
<Route path="/tool-library" element={<ToolLibrary />} />
|
||||
<Route path="/credentials" element={<CredentialsPage />} />
|
||||
<Route path="*" element={<NotFound />} />
|
||||
</Route>
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import { api } from "./client";
|
||||
import type { ToolMeta, McpServerTools } from "./queens";
|
||||
|
||||
export interface ColonySummary {
|
||||
name: string;
|
||||
queen_name: string | null;
|
||||
created_at: string | null;
|
||||
has_allowlist: boolean;
|
||||
enabled_count: number | null;
|
||||
}
|
||||
|
||||
export interface ColonyToolsResponse {
|
||||
colony_name: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
}
|
||||
|
||||
export interface ColonyToolsUpdateResult {
|
||||
colony_name: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
refreshed_runtimes: number;
|
||||
note?: string;
|
||||
}
|
||||
|
||||
export const coloniesApi = {
|
||||
/** List every colony on disk with a summary of its tool allowlist. */
|
||||
list: () =>
|
||||
api.get<{ colonies: ColonySummary[] }>(`/colonies/tools-index`),
|
||||
|
||||
/** Enumerate a colony's tool surface (lifecycle + synthetic + MCP). */
|
||||
getTools: (colonyName: string) =>
|
||||
api.get<ColonyToolsResponse>(
|
||||
`/colony/${encodeURIComponent(colonyName)}/tools`,
|
||||
),
|
||||
|
||||
/** Persist a colony's MCP tool allowlist.
|
||||
*
|
||||
* ``null`` resets to "allow every MCP tool". A list of names enables
|
||||
* only those MCP tools. Changes take effect on the next worker spawn;
|
||||
* in-flight workers keep their booted tool list.
|
||||
*/
|
||||
updateTools: (colonyName: string, enabled: string[] | null) =>
|
||||
api.patch<ColonyToolsUpdateResult>(
|
||||
`/colony/${encodeURIComponent(colonyName)}/tools`,
|
||||
{ enabled_mcp_tools: enabled },
|
||||
),
|
||||
};
|
||||
@@ -0,0 +1,66 @@
|
||||
import { api } from "./client";
|
||||
|
||||
export type McpTransport = "stdio" | "http" | "sse" | "unix";
|
||||
|
||||
export interface McpServer {
|
||||
name: string;
|
||||
/** "local": added via UI/CLI (user-editable). "registry": installed from
|
||||
* the remote MCP registry. "built-in": baked into the queen package —
|
||||
* visible but not removable from the UI. */
|
||||
source: "local" | "registry" | "built-in";
|
||||
transport: McpTransport | string;
|
||||
description: string;
|
||||
enabled: boolean;
|
||||
last_health_status: "healthy" | "unhealthy" | null;
|
||||
last_error: string | null;
|
||||
last_health_check_at: string | null;
|
||||
tool_count: number | null;
|
||||
/** Servers flagged removable:false cannot be deleted from the UI. */
|
||||
removable?: boolean;
|
||||
}
|
||||
|
||||
export interface AddMcpServerBody {
|
||||
name: string;
|
||||
transport: McpTransport;
|
||||
/** stdio */
|
||||
command?: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
cwd?: string;
|
||||
/** http / sse */
|
||||
url?: string;
|
||||
headers?: Record<string, string>;
|
||||
/** unix */
|
||||
socket_path?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface McpHealthResult {
|
||||
name: string;
|
||||
status: "healthy" | "unhealthy" | "unknown";
|
||||
tools: number;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
/** Backend MCPError shape when an operation fails. */
|
||||
export interface McpErrorBody {
|
||||
error: string;
|
||||
code?: string;
|
||||
what?: string;
|
||||
why?: string;
|
||||
fix?: string;
|
||||
}
|
||||
|
||||
export const mcpApi = {
|
||||
listServers: () => api.get<{ servers: McpServer[] }>("/mcp/servers"),
|
||||
addServer: (body: AddMcpServerBody) =>
|
||||
api.post<{ server: McpServer; hint: string }>("/mcp/servers", body),
|
||||
removeServer: (name: string) =>
|
||||
api.delete<{ removed: string }>(`/mcp/servers/${encodeURIComponent(name)}`),
|
||||
setEnabled: (name: string, enabled: boolean) =>
|
||||
api.post<{ name: string; enabled: boolean }>(
|
||||
`/mcp/servers/${encodeURIComponent(name)}/${enabled ? "enable" : "disable"}`,
|
||||
),
|
||||
checkHealth: (name: string) =>
|
||||
api.post<McpHealthResult>(`/mcp/servers/${encodeURIComponent(name)}/health`),
|
||||
};
|
||||
@@ -16,6 +16,45 @@ export interface QueenSessionResult {
|
||||
status: "live" | "resumed" | "created";
|
||||
}
|
||||
|
||||
export interface ToolMeta {
|
||||
name: string;
|
||||
description: string;
|
||||
input_schema?: Record<string, unknown>;
|
||||
editable?: boolean;
|
||||
}
|
||||
|
||||
export interface McpServerTools {
|
||||
name: string;
|
||||
tools: Array<ToolMeta & { enabled: boolean }>;
|
||||
}
|
||||
|
||||
export interface QueenToolsResponse {
|
||||
queen_id: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
/** True when the effective allowlist comes from the role-based default
|
||||
* (no tools.json sidecar saved for this queen). False means the user
|
||||
* has explicitly saved an allowlist. */
|
||||
is_role_default: boolean;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
}
|
||||
|
||||
export interface QueenToolsUpdateResult {
|
||||
queen_id: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
refreshed_sessions: number;
|
||||
}
|
||||
|
||||
export interface QueenToolsResetResult {
|
||||
queen_id: string;
|
||||
removed: boolean;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
is_role_default: true;
|
||||
refreshed_sessions: number;
|
||||
}
|
||||
|
||||
export const queensApi = {
|
||||
/** List all queen profiles (id, name, title). */
|
||||
list: () =>
|
||||
@@ -57,4 +96,24 @@ export const queensApi = {
|
||||
initial_prompt: initialPrompt,
|
||||
initial_phase: initialPhase || undefined,
|
||||
}),
|
||||
|
||||
/** Enumerate the queen's tool surface (lifecycle + synthetic + MCP). */
|
||||
getTools: (queenId: string) =>
|
||||
api.get<QueenToolsResponse>(`/queen/${queenId}/tools`),
|
||||
|
||||
/** Persist the MCP tool allowlist for a queen.
|
||||
*
|
||||
* Pass ``null`` to explicitly allow every MCP tool, or a list to
|
||||
* restrict the queen's tool surface. Lifecycle and synthetic tools
|
||||
* are always enabled and cannot be listed here.
|
||||
*/
|
||||
updateTools: (queenId: string, enabled: string[] | null) =>
|
||||
api.patch<QueenToolsUpdateResult>(`/queen/${queenId}/tools`, {
|
||||
enabled_mcp_tools: enabled,
|
||||
}),
|
||||
|
||||
/** Drop the queen's tools.json sidecar so she falls back to the
|
||||
* role-based default (or allow-all for queens without a role entry). */
|
||||
resetTools: (queenId: string) =>
|
||||
api.delete<QueenToolsResetResult>(`/queen/${queenId}/tools`),
|
||||
};
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
import { api } from "./client";
|
||||
|
||||
export type SkillScopeKind = "queen" | "colony" | "user";
|
||||
|
||||
export type SkillProvenance =
|
||||
| "framework"
|
||||
| "preset"
|
||||
| "user_dropped"
|
||||
| "user_ui_created"
|
||||
| "queen_created"
|
||||
| "learned_runtime"
|
||||
| "project_dropped"
|
||||
| "other";
|
||||
|
||||
export interface SkillOwner {
|
||||
type: "queen" | "colony";
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
export interface SkillRow {
|
||||
name: string;
|
||||
description: string;
|
||||
source_scope: string;
|
||||
provenance: SkillProvenance;
|
||||
enabled: boolean;
|
||||
editable: boolean;
|
||||
deletable: boolean;
|
||||
location: string;
|
||||
base_dir?: string;
|
||||
visibility: string[] | null;
|
||||
trust: string | null;
|
||||
created_at: string | null;
|
||||
created_by: string | null;
|
||||
notes: string | null;
|
||||
param_overrides?: Record<string, unknown>;
|
||||
owner?: SkillOwner | null;
|
||||
visible_to?: { queens: string[]; colonies: string[] };
|
||||
enabled_by_default?: boolean;
|
||||
}
|
||||
|
||||
export interface ScopeSkillsResponse {
|
||||
queen_id?: string;
|
||||
colony_name?: string;
|
||||
all_defaults_disabled: boolean;
|
||||
skills: SkillRow[];
|
||||
inherited_from_queen?: string[];
|
||||
}
|
||||
|
||||
export interface AggregatedSkillsResponse {
|
||||
skills: SkillRow[];
|
||||
queens: Array<{ id: string; name: string }>;
|
||||
colonies: Array<{ name: string; queen_id: string | null }>;
|
||||
}
|
||||
|
||||
export interface SkillScopesResponse {
|
||||
queens: Array<{ id: string; name: string }>;
|
||||
colonies: Array<{ name: string; queen_id: string | null }>;
|
||||
}
|
||||
|
||||
export interface SkillDetailResponse {
|
||||
name: string;
|
||||
description: string;
|
||||
source_scope: string;
|
||||
location: string;
|
||||
base_dir: string;
|
||||
body: string;
|
||||
visibility: string[] | null;
|
||||
}
|
||||
|
||||
export interface SkillCreatePayload {
|
||||
name: string;
|
||||
description: string;
|
||||
body: string;
|
||||
files?: Array<{ path: string; content: string }>;
|
||||
enabled?: boolean;
|
||||
notes?: string | null;
|
||||
replace_existing?: boolean;
|
||||
}
|
||||
|
||||
export interface SkillPatchPayload {
|
||||
enabled?: boolean;
|
||||
param_overrides?: Record<string, unknown>;
|
||||
notes?: string | null;
|
||||
all_defaults_disabled?: boolean;
|
||||
}
|
||||
|
||||
const scopePath = (scope: "queen" | "colony", targetId: string) =>
|
||||
scope === "queen"
|
||||
? `/queen/${encodeURIComponent(targetId)}/skills`
|
||||
: `/colonies/${encodeURIComponent(targetId)}/skills`;
|
||||
|
||||
export const skillsApi = {
|
||||
// Aggregated library
|
||||
listAll: () => api.get<AggregatedSkillsResponse>("/skills"),
|
||||
listScopes: () => api.get<SkillScopesResponse>("/skills/scopes"),
|
||||
getDetail: (name: string) =>
|
||||
api.get<SkillDetailResponse>(`/skills/${encodeURIComponent(name)}`),
|
||||
|
||||
// Per-scope
|
||||
listForQueen: (queenId: string) =>
|
||||
api.get<ScopeSkillsResponse>(`/queen/${encodeURIComponent(queenId)}/skills`),
|
||||
listForColony: (colonyName: string) =>
|
||||
api.get<ScopeSkillsResponse>(
|
||||
`/colonies/${encodeURIComponent(colonyName)}/skills`,
|
||||
),
|
||||
|
||||
create: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
payload: SkillCreatePayload,
|
||||
) => api.post<SkillRow>(scopePath(scope, targetId), payload),
|
||||
|
||||
patch: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
skillName: string,
|
||||
payload: SkillPatchPayload,
|
||||
) =>
|
||||
api.patch<{ name: string; enabled: boolean | null; ok: boolean }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}`,
|
||||
payload,
|
||||
),
|
||||
|
||||
putBody: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
skillName: string,
|
||||
payload: { body: string; description?: string },
|
||||
) =>
|
||||
api.put<{ name: string; installed_path: string }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}/body`,
|
||||
payload,
|
||||
),
|
||||
|
||||
remove: (scope: "queen" | "colony", targetId: string, skillName: string) =>
|
||||
api.delete<{ name: string; removed: boolean }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}`,
|
||||
),
|
||||
|
||||
reload: (scope: "queen" | "colony", targetId: string) =>
|
||||
api.post<{ ok: boolean }>(`${scopePath(scope, targetId)}/reload`),
|
||||
|
||||
// Multipart upload. File may be a SKILL.md or a .zip bundle.
|
||||
upload: (formData: FormData) =>
|
||||
api.upload<{
|
||||
name: string;
|
||||
installed_path: string;
|
||||
replaced: boolean;
|
||||
scope: SkillScopeKind;
|
||||
target_id: string | null;
|
||||
enabled: boolean;
|
||||
}>("/skills/upload", formData),
|
||||
};
|
||||
@@ -151,8 +151,11 @@ interface ChatPanelProps {
|
||||
onStartNewSession?: () => void;
|
||||
/** When true, disable the start-new-session button (request in flight). */
|
||||
startingNewSession?: boolean;
|
||||
/** Cumulative LLM token usage for this session */
|
||||
tokenUsage?: { input: number; output: number };
|
||||
/** Cumulative LLM token usage for this session.
|
||||
* `cached` (cache reads) and `cacheCreated` (cache writes) are subsets of
|
||||
* `input` — providers count both inside prompt_tokens. Display them
|
||||
* separately; do not add to a total. */
|
||||
tokenUsage?: { input: number; output: number; cached?: number; cacheCreated?: number; costUsd?: number };
|
||||
/** Optional action element rendered on the right side of the "Conversation" header */
|
||||
headerAction?: React.ReactNode;
|
||||
}
|
||||
@@ -1482,11 +1485,41 @@ export default function ChatPanel({
|
||||
Context: {fmt(queenUsage.estimatedTokens)}/{fmt(queenUsage.maxTokens)}
|
||||
</span>
|
||||
)}
|
||||
{hasTokens && (
|
||||
<span title="LLM tokens used this session (input + output)">
|
||||
Tokens: {fmt(tokenUsage!.input + tokenUsage!.output)}
|
||||
</span>
|
||||
)}
|
||||
{hasTokens && (() => {
|
||||
const cached = tokenUsage!.cached ?? 0;
|
||||
const created = tokenUsage!.cacheCreated ?? 0;
|
||||
const cost = tokenUsage!.costUsd ?? 0;
|
||||
// cached/created are subsets of input — never sum; surface separately.
|
||||
// Cost can be < $0.01; show 4 decimals so small-model sessions aren't "$0.00".
|
||||
const costStr = cost > 0 ? `$${cost.toFixed(4)}` : "—";
|
||||
return (
|
||||
<span className="group relative cursor-help transition-colors hover:text-muted-foreground">
|
||||
Tokens: {fmt(tokenUsage!.input + tokenUsage!.output)}
|
||||
<span
|
||||
role="tooltip"
|
||||
className="pointer-events-none invisible absolute bottom-full right-0 z-50 mb-2 whitespace-nowrap rounded-md border border-border bg-popover px-3 py-2 text-[11px] text-popover-foreground opacity-0 shadow-lg transition-[opacity,transform] duration-150 translate-y-1 group-hover:visible group-hover:opacity-100 group-hover:translate-y-0"
|
||||
>
|
||||
<span className="mb-1.5 block text-muted-foreground">
|
||||
LLM tokens used this session
|
||||
</span>
|
||||
<span className="grid grid-cols-[auto_1fr] gap-x-4 gap-y-0.5 tabular-nums">
|
||||
<span>Input</span>
|
||||
<span className="text-right">{fmt(tokenUsage!.input)}</span>
|
||||
<span className="pl-3 text-muted-foreground">cache read</span>
|
||||
<span className="text-right text-muted-foreground">{fmt(cached)}</span>
|
||||
<span className="pl-3 text-muted-foreground">cache write</span>
|
||||
<span className="text-right text-muted-foreground">{fmt(created)}</span>
|
||||
<span>Output</span>
|
||||
<span className="text-right">{fmt(tokenUsage!.output)}</span>
|
||||
<span className="mt-1 border-t border-border/50 pt-1">Cost</span>
|
||||
<span className="mt-1 border-t border-border/50 pt-1 text-right font-medium">
|
||||
{costStr}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
})()}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { useCallback } from "react";
|
||||
import { coloniesApi } from "@/api/colonies";
|
||||
import ToolsEditor from "./ToolsEditor";
|
||||
|
||||
export default function ColonyToolsSection({
|
||||
colonyName,
|
||||
}: {
|
||||
colonyName: string;
|
||||
}) {
|
||||
const fetchSnapshot = useCallback(
|
||||
() => coloniesApi.getTools(colonyName),
|
||||
[colonyName],
|
||||
);
|
||||
const saveAllowlist = useCallback(
|
||||
(enabled: string[] | null) => coloniesApi.updateTools(colonyName, enabled),
|
||||
[colonyName],
|
||||
);
|
||||
return (
|
||||
<ToolsEditor
|
||||
subjectKey={`colony:${colonyName}`}
|
||||
title="Tools"
|
||||
caveat="Changes apply to the next worker spawn. Running workers keep the tool list they booted with."
|
||||
fetchSnapshot={fetchSnapshot}
|
||||
saveAllowlist={saveAllowlist}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,651 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
Plus,
|
||||
Trash2,
|
||||
RefreshCw,
|
||||
Loader2,
|
||||
AlertCircle,
|
||||
Check,
|
||||
X,
|
||||
Server,
|
||||
CircleCheck,
|
||||
CircleAlert,
|
||||
CircleDashed,
|
||||
} from "lucide-react";
|
||||
import {
|
||||
mcpApi,
|
||||
type McpServer,
|
||||
type McpTransport,
|
||||
type AddMcpServerBody,
|
||||
} from "@/api/mcp";
|
||||
|
||||
type TransportKey = McpTransport;
|
||||
|
||||
const TRANSPORT_OPTIONS: TransportKey[] = ["stdio", "http", "sse", "unix"];
|
||||
|
||||
function healthBadge(server: McpServer) {
|
||||
if (!server.enabled) {
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-muted-foreground">
|
||||
<CircleDashed className="w-3 h-3" /> Disabled
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (server.last_health_status === "healthy") {
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-green-500">
|
||||
<CircleCheck className="w-3 h-3" /> Healthy
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (server.last_health_status === "unhealthy") {
|
||||
return (
|
||||
<span
|
||||
className="flex items-center gap-1 text-[11px] text-red-400"
|
||||
title={server.last_error || "Unhealthy"}
|
||||
>
|
||||
<CircleAlert className="w-3 h-3" /> Unhealthy
|
||||
</span>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-muted-foreground">
|
||||
<CircleDashed className="w-3 h-3" /> Unknown
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
interface AddFormState {
|
||||
name: string;
|
||||
transport: TransportKey;
|
||||
command: string;
|
||||
args: string;
|
||||
env: string;
|
||||
cwd: string;
|
||||
url: string;
|
||||
headers: string;
|
||||
socketPath: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
const EMPTY_FORM: AddFormState = {
|
||||
name: "",
|
||||
transport: "stdio",
|
||||
command: "",
|
||||
args: "",
|
||||
env: "",
|
||||
cwd: "",
|
||||
url: "",
|
||||
headers: "",
|
||||
socketPath: "",
|
||||
description: "",
|
||||
};
|
||||
|
||||
function parseKeyValueLines(text: string): Record<string, string> {
|
||||
const out: Record<string, string> = {};
|
||||
text
|
||||
.split("\n")
|
||||
.map((l) => l.trim())
|
||||
.filter(Boolean)
|
||||
.forEach((line) => {
|
||||
const eq = line.indexOf("=");
|
||||
if (eq < 0) return;
|
||||
const k = line.slice(0, eq).trim();
|
||||
const v = line.slice(eq + 1).trim();
|
||||
if (k) out[k] = v;
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
function buildAddBody(form: AddFormState): AddMcpServerBody {
|
||||
const body: AddMcpServerBody = {
|
||||
name: form.name.trim(),
|
||||
transport: form.transport,
|
||||
description: form.description.trim() || undefined,
|
||||
};
|
||||
if (form.transport === "stdio") {
|
||||
body.command = form.command.trim();
|
||||
const args = form.args
|
||||
.split("\n")
|
||||
.map((s) => s.trim())
|
||||
.filter(Boolean);
|
||||
if (args.length) body.args = args;
|
||||
const env = parseKeyValueLines(form.env);
|
||||
if (Object.keys(env).length) body.env = env;
|
||||
if (form.cwd.trim()) body.cwd = form.cwd.trim();
|
||||
} else if (form.transport === "http" || form.transport === "sse") {
|
||||
body.url = form.url.trim();
|
||||
const headers = parseKeyValueLines(form.headers);
|
||||
if (Object.keys(headers).length) body.headers = headers;
|
||||
} else if (form.transport === "unix") {
|
||||
body.socket_path = form.socketPath.trim();
|
||||
}
|
||||
return body;
|
||||
}
|
||||
|
||||
export default function McpServersPanel() {
|
||||
const [servers, setServers] = useState<McpServer[] | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const [adding, setAdding] = useState(false);
|
||||
const [form, setForm] = useState<AddFormState>(EMPTY_FORM);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
const [submitError, setSubmitError] = useState<string | null>(null);
|
||||
|
||||
const [busyByName, setBusyByName] = useState<Record<string, boolean>>({});
|
||||
|
||||
const refresh = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const { servers } = await mcpApi.listServers();
|
||||
setServers(servers);
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Failed to load MCP servers");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refresh();
|
||||
}, []);
|
||||
|
||||
const setBusy = (name: string, v: boolean) =>
|
||||
setBusyByName((p) => ({ ...p, [name]: v }));
|
||||
|
||||
const handleToggle = async (server: McpServer) => {
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.setEnabled(server.name, !server.enabled);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Toggle failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRemove = async (server: McpServer) => {
|
||||
if (!confirm(`Remove MCP server "${server.name}"?`)) return;
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.removeServer(server.name);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
const body = (e as { body?: { error?: string } }).body;
|
||||
setError(body?.error || (e as Error)?.message || "Remove failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleHealth = async (server: McpServer) => {
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.checkHealth(server.name);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Health check failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const canSubmit = (() => {
|
||||
if (!form.name.trim()) return false;
|
||||
if (form.transport === "stdio") return !!form.command.trim();
|
||||
if (form.transport === "http" || form.transport === "sse")
|
||||
return !!form.url.trim();
|
||||
if (form.transport === "unix") return !!form.socketPath.trim();
|
||||
return false;
|
||||
})();
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!canSubmit) return;
|
||||
setSubmitting(true);
|
||||
setSubmitError(null);
|
||||
try {
|
||||
const body = buildAddBody(form);
|
||||
const { server } = await mcpApi.addServer(body);
|
||||
// Best-effort: auto-run health check so the UI shows tool count.
|
||||
try {
|
||||
await mcpApi.checkHealth(server.name);
|
||||
} catch {
|
||||
/* health check is informational; don't block the add flow */
|
||||
}
|
||||
setAdding(false);
|
||||
setForm(EMPTY_FORM);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
const body = (e as { body?: { error?: string; fix?: string } }).body;
|
||||
setSubmitError(
|
||||
[body?.error, body?.fix].filter(Boolean).join(" — ") ||
|
||||
(e as Error)?.message ||
|
||||
"Add failed",
|
||||
);
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Group by origin. "local" = user-registered via the UI or CLI. Everything
|
||||
// else (built-in package entries, registry-installed entries) sits under
|
||||
// "Built-in" since the user can't remove them from the UI.
|
||||
const builtIns = (servers || []).filter((s) => s.source !== "local");
|
||||
const custom = (servers || []).filter((s) => s.source === "local");
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-foreground">MCP Servers</h3>
|
||||
<p className="text-sm text-muted-foreground mt-1">
|
||||
Register your own MCP servers so queens can use their tools. New
|
||||
servers take effect in the next queen session you start.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={refresh}
|
||||
disabled={loading}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md border border-border/60 text-xs text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50"
|
||||
title="Refresh"
|
||||
>
|
||||
<RefreshCw className={`w-3 h-3 ${loading ? "animate-spin" : ""}`} />
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setAdding(true);
|
||||
setForm(EMPTY_FORM);
|
||||
setSubmitError(null);
|
||||
}}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-semibold hover:bg-primary/90"
|
||||
>
|
||||
<Plus className="w-3 h-3" />
|
||||
Add MCP Server
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive p-2.5 rounded-md bg-destructive/10 border border-destructive/30">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span className="flex-1">{error}</span>
|
||||
<button
|
||||
onClick={() => setError(null)}
|
||||
className="text-destructive/70 hover:text-destructive"
|
||||
>
|
||||
<X className="w-3 h-3" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{loading && !servers && (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<Loader2 className="w-3 h-3 animate-spin" /> Loading MCP servers…
|
||||
</div>
|
||||
)}
|
||||
|
||||
{servers && (
|
||||
<>
|
||||
{custom.length > 0 && (
|
||||
<Section title="My Custom">
|
||||
{custom.map((s) => (
|
||||
<ServerRow
|
||||
key={s.name}
|
||||
server={s}
|
||||
busy={!!busyByName[s.name]}
|
||||
onToggle={() => handleToggle(s)}
|
||||
onRemove={() => handleRemove(s)}
|
||||
onHealth={() => handleHealth(s)}
|
||||
isLocal
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
)}
|
||||
<Section title="Built-in">
|
||||
{builtIns.length === 0 ? (
|
||||
<p className="text-xs text-muted-foreground px-2 py-2">
|
||||
No built-in servers registered.
|
||||
</p>
|
||||
) : (
|
||||
builtIns.map((s) => (
|
||||
<ServerRow
|
||||
key={s.name}
|
||||
server={s}
|
||||
busy={!!busyByName[s.name]}
|
||||
onToggle={() => handleToggle(s)}
|
||||
onRemove={() => handleRemove(s)}
|
||||
onHealth={() => handleHealth(s)}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</Section>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Add MCP modal */}
|
||||
{adding && (
|
||||
<div className="fixed inset-0 z-[60] flex items-center justify-center">
|
||||
<div
|
||||
className="absolute inset-0 bg-black/50"
|
||||
onClick={() => !submitting && setAdding(false)}
|
||||
/>
|
||||
<div className="relative bg-card border border-border/60 rounded-xl shadow-2xl w-full max-w-lg p-5 space-y-4 max-h-[85vh] overflow-y-auto">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Add MCP Server
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => !submitting && setAdding(false)}
|
||||
className="p-1 rounded text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<X className="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<FieldRow label="Name *" hint="Unique identifier, e.g. my-search-tool">
|
||||
<input
|
||||
autoFocus
|
||||
value={form.name}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({
|
||||
...f,
|
||||
name: e.target.value.toLowerCase().replace(/[^a-z0-9_-]/g, ""),
|
||||
}))
|
||||
}
|
||||
placeholder="my-search-tool"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
|
||||
<FieldRow label="Transport *">
|
||||
<div className="flex gap-1">
|
||||
{TRANSPORT_OPTIONS.map((t) => (
|
||||
<button
|
||||
key={t}
|
||||
onClick={() => setForm((f) => ({ ...f, transport: t }))}
|
||||
className={`flex-1 px-3 py-1.5 rounded-md text-xs font-medium border ${
|
||||
form.transport === t
|
||||
? "bg-primary/15 text-primary border-primary/40"
|
||||
: "text-muted-foreground hover:text-foreground border-border/60 hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
{t}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</FieldRow>
|
||||
|
||||
{form.transport === "stdio" && (
|
||||
<>
|
||||
<FieldRow
|
||||
label="Command *"
|
||||
hint="Executable that speaks MCP over stdin/stdout"
|
||||
>
|
||||
<input
|
||||
value={form.command}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, command: e.target.value }))
|
||||
}
|
||||
placeholder="uv"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Args (one per line)">
|
||||
<textarea
|
||||
value={form.args}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, args: e.target.value }))
|
||||
}
|
||||
rows={3}
|
||||
placeholder={"run\npython\nmy_server.py\n--stdio"}
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Env (KEY=VALUE, one per line)">
|
||||
<textarea
|
||||
value={form.env}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, env: e.target.value }))
|
||||
}
|
||||
rows={2}
|
||||
placeholder="API_KEY=abc123"
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Working directory">
|
||||
<input
|
||||
value={form.cwd}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, cwd: e.target.value }))
|
||||
}
|
||||
placeholder="/path/to/repo"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
</>
|
||||
)}
|
||||
|
||||
{(form.transport === "http" || form.transport === "sse") && (
|
||||
<>
|
||||
<FieldRow label="URL *">
|
||||
<input
|
||||
value={form.url}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, url: e.target.value }))
|
||||
}
|
||||
placeholder="https://example.com/mcp"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Headers (KEY=VALUE, one per line)">
|
||||
<textarea
|
||||
value={form.headers}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, headers: e.target.value }))
|
||||
}
|
||||
rows={2}
|
||||
placeholder="Authorization=Bearer ..."
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
</>
|
||||
)}
|
||||
|
||||
{form.transport === "unix" && (
|
||||
<FieldRow label="Socket path *">
|
||||
<input
|
||||
value={form.socketPath}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, socketPath: e.target.value }))
|
||||
}
|
||||
placeholder="/tmp/mcp.sock"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
)}
|
||||
|
||||
<FieldRow label="Description">
|
||||
<input
|
||||
value={form.description}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, description: e.target.value }))
|
||||
}
|
||||
placeholder="What this server does"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
|
||||
{submitError && (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive p-2 rounded-md bg-destructive/10 border border-destructive/30">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{submitError}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-end gap-2 pt-1">
|
||||
<button
|
||||
onClick={() => setAdding(false)}
|
||||
disabled={submitting}
|
||||
className="px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSubmit}
|
||||
disabled={!canSubmit || submitting}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-semibold hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{submitting ? (
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
) : (
|
||||
<Check className="w-3 h-3" />
|
||||
)}
|
||||
{submitting ? "Adding…" : "Add"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const inputCls =
|
||||
"w-full bg-muted/30 border border-border/50 rounded-lg px-3 py-2 text-sm text-foreground focus:outline-none focus:ring-1 focus:ring-primary/40";
|
||||
const textareaCls = `${inputCls} resize-none font-mono text-xs`;
|
||||
|
||||
function FieldRow({
|
||||
label,
|
||||
hint,
|
||||
children,
|
||||
}: {
|
||||
label: string;
|
||||
hint?: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<label className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider mb-1.5 block">
|
||||
{label}
|
||||
</label>
|
||||
{children}
|
||||
{hint && (
|
||||
<p className="text-[11px] text-muted-foreground/70 mt-1">{hint}</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function Section({
|
||||
title,
|
||||
children,
|
||||
}: {
|
||||
title: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<p className="text-[11px] font-semibold text-muted-foreground/60 uppercase tracking-wider mb-2">
|
||||
{title}
|
||||
</p>
|
||||
<div className="flex flex-col gap-1">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ServerRow({
|
||||
server,
|
||||
busy,
|
||||
onToggle,
|
||||
onRemove,
|
||||
onHealth,
|
||||
isLocal,
|
||||
}: {
|
||||
server: McpServer;
|
||||
busy: boolean;
|
||||
onToggle: () => void;
|
||||
onRemove: () => void;
|
||||
onHealth: () => void;
|
||||
isLocal?: boolean;
|
||||
}) {
|
||||
// Package-baked servers live in the repo and aren't managed by
|
||||
// MCPRegistry, so toggling / removing / health-checking them would
|
||||
// fail against the backend. Show them as read-only.
|
||||
const isBuiltIn = server.source === "built-in";
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2.5 px-2 rounded-lg hover:bg-muted/20">
|
||||
<div className="w-9 h-9 rounded-full bg-primary/10 flex items-center justify-center flex-shrink-0">
|
||||
<Server className="w-4 h-4 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="text-sm font-medium text-foreground truncate">
|
||||
{server.name}
|
||||
</p>
|
||||
<span className="text-[10px] uppercase tracking-wider text-muted-foreground/60">
|
||||
{server.transport}
|
||||
</span>
|
||||
{isBuiltIn && (
|
||||
<span className="text-[10px] uppercase tracking-wider text-muted-foreground/80 bg-muted/40 px-1.5 py-0.5 rounded">
|
||||
Built-in
|
||||
</span>
|
||||
)}
|
||||
{server.tool_count !== null && server.tool_count !== undefined && (
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{server.tool_count} tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{!isBuiltIn && healthBadge(server)}
|
||||
{server.description && (
|
||||
<span className="text-xs text-muted-foreground truncate">
|
||||
{isBuiltIn ? server.description : `· ${server.description}`}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{!isBuiltIn && (
|
||||
<>
|
||||
<button
|
||||
onClick={onHealth}
|
||||
disabled={busy}
|
||||
className="p-1.5 rounded-md text-muted-foreground hover:text-foreground hover:bg-muted/40 disabled:opacity-50"
|
||||
title="Health check"
|
||||
>
|
||||
{busy ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<RefreshCw className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={onToggle}
|
||||
disabled={busy}
|
||||
className={`px-3 py-1 rounded-md text-[11px] font-semibold border disabled:opacity-50 ${
|
||||
server.enabled
|
||||
? "text-muted-foreground border-border/60 hover:bg-muted/30"
|
||||
: "bg-primary/15 text-primary border-primary/40 hover:bg-primary/25"
|
||||
}`}
|
||||
>
|
||||
{server.enabled ? "Disable" : "Enable"}
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
{isLocal && !isBuiltIn && (
|
||||
<button
|
||||
onClick={onRemove}
|
||||
disabled={busy}
|
||||
className="p-1.5 rounded-md text-muted-foreground hover:text-red-400 hover:bg-red-500/10 disabled:opacity-50"
|
||||
title="Remove"
|
||||
>
|
||||
<Trash2 className="w-3.5 h-3.5" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import { executionApi } from "@/api/execution";
|
||||
import { compressImage } from "@/lib/image-utils";
|
||||
import type { Colony } from "@/types/colony";
|
||||
import { slugToColonyId } from "@/lib/colony-registry";
|
||||
import QueenToolsSection from "./QueenToolsSection";
|
||||
|
||||
interface QueenProfilePanelProps {
|
||||
queenId: string;
|
||||
@@ -354,6 +355,10 @@ export default function QueenProfilePanel({ queenId, colonies, onClose }: QueenP
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-6">
|
||||
<QueenToolsSection queenId={queenId} />
|
||||
</div>
|
||||
|
||||
{colonies.length > 0 && (
|
||||
<div>
|
||||
<h4 className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider mb-2">Assigned Colonies</h4>
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { useCallback } from "react";
|
||||
import { queensApi } from "@/api/queens";
|
||||
import ToolsEditor from "./ToolsEditor";
|
||||
|
||||
export default function QueenToolsSection({ queenId }: { queenId: string }) {
|
||||
const fetchSnapshot = useCallback(
|
||||
() => queensApi.getTools(queenId),
|
||||
[queenId],
|
||||
);
|
||||
const saveAllowlist = useCallback(
|
||||
(enabled: string[] | null) => queensApi.updateTools(queenId, enabled),
|
||||
[queenId],
|
||||
);
|
||||
const resetToRoleDefault = useCallback(
|
||||
() => queensApi.resetTools(queenId),
|
||||
[queenId],
|
||||
);
|
||||
return (
|
||||
<ToolsEditor
|
||||
subjectKey={`queen:${queenId}`}
|
||||
title="Tools"
|
||||
fetchSnapshot={fetchSnapshot}
|
||||
saveAllowlist={saveAllowlist}
|
||||
resetToRoleDefault={resetToRoleDefault}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -6,11 +6,12 @@ import { useModel, LLM_PROVIDERS } from "@/context/ModelContext";
|
||||
import { credentialsApi } from "@/api/credentials";
|
||||
import { configApi, type ModelOption } from "@/api/config";
|
||||
import { compressImage } from "@/lib/image-utils";
|
||||
import McpServersPanel from "./McpServersPanel";
|
||||
|
||||
interface SettingsModalProps {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
initialSection?: "profile" | "byok";
|
||||
initialSection?: "profile" | "byok" | "mcp";
|
||||
}
|
||||
|
||||
function ValidationBadge({ state }: { state: "validating" | { valid: boolean | null; message: string } | undefined }) {
|
||||
@@ -37,7 +38,7 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
|
||||
const [displayName, setDisplayName] = useState(userProfile.displayName);
|
||||
const [about, setAbout] = useState(userProfile.about);
|
||||
const [activeSection, setActiveSection] = useState<"profile" | "byok">(initialSection || "profile");
|
||||
const [activeSection, setActiveSection] = useState<"profile" | "byok" | "mcp">(initialSection || "profile");
|
||||
const [editingProvider, setEditingProvider] = useState<string | null>(null);
|
||||
const [keyInput, setKeyInput] = useState("");
|
||||
const [showKey, setShowKey] = useState(false);
|
||||
@@ -187,6 +188,12 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
>
|
||||
BYOK
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setActiveSection("mcp")}
|
||||
className={`text-left text-sm px-3 py-1.5 rounded-md ${activeSection === "mcp" ? "bg-primary/15 text-primary font-medium" : "text-muted-foreground hover:text-foreground hover:bg-muted/30"}`}
|
||||
>
|
||||
MCP Servers
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -267,6 +274,8 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
</>
|
||||
)}
|
||||
|
||||
{activeSection === "mcp" && <McpServersPanel />}
|
||||
|
||||
{activeSection === "byok" && (
|
||||
<>
|
||||
<div>
|
||||
|
||||
@@ -5,13 +5,13 @@ import {
|
||||
ChevronRight,
|
||||
MessageSquarePlus,
|
||||
Network,
|
||||
Sparkles,
|
||||
KeyRound,
|
||||
ChevronDown,
|
||||
Plus,
|
||||
X,
|
||||
Crown,
|
||||
Loader2,
|
||||
Library,
|
||||
} from "lucide-react";
|
||||
import SidebarColonyItem from "./SidebarColonyItem";
|
||||
import SidebarQueenItem from "./SidebarQueenItem";
|
||||
@@ -28,6 +28,7 @@ export default function Sidebar() {
|
||||
);
|
||||
const [coloniesExpanded, setColoniesExpanded] = useState(true);
|
||||
const [queensExpanded, setQueensExpanded] = useState(true);
|
||||
const [libraryExpanded, setLibraryExpanded] = useState(false);
|
||||
|
||||
// Colony creation
|
||||
const [createColonyOpen, setCreateColonyOpen] = useState(false);
|
||||
@@ -165,13 +166,6 @@ export default function Sidebar() {
|
||||
<Network className="w-4 h-4" />
|
||||
<span>Org Chart</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/prompt-library")}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<Sparkles className="w-4 h-4" />
|
||||
<span>Prompt Library</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/credentials")}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
@@ -179,6 +173,40 @@ export default function Sidebar() {
|
||||
<KeyRound className="w-4 h-4" />
|
||||
<span>Credentials</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setLibraryExpanded((v) => !v)}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<Library className="w-4 h-4" />
|
||||
<span className="flex-1 text-left">Configuration</span>
|
||||
<ChevronDown
|
||||
className={`w-3.5 h-3.5 transition-transform ${
|
||||
libraryExpanded ? "" : "-rotate-90"
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
{libraryExpanded && (
|
||||
<>
|
||||
<button
|
||||
onClick={() => navigate("/prompt-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Prompts</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/skills-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Skills</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/tool-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Tools</span>
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* COLONIES section */}
|
||||
|
||||
@@ -0,0 +1,508 @@
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
ChevronDown,
|
||||
ChevronRight,
|
||||
Check,
|
||||
Loader2,
|
||||
Lock,
|
||||
Wrench,
|
||||
AlertCircle,
|
||||
} from "lucide-react";
|
||||
import type { ToolMeta, McpServerTools } from "@/api/queens";
|
||||
|
||||
/** Shape every Tools section (Queen / Colony) shares. */
|
||||
export interface ToolsSnapshot {
|
||||
enabled_mcp_tools: string[] | null;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
/** Optional: when true, the allowlist came from the role-based
|
||||
* default (no explicit save). Only queens surface this today. */
|
||||
is_role_default?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolsEditorProps {
|
||||
/** Stable identifier — refetches when it changes. */
|
||||
subjectKey: string;
|
||||
/** Title shown above the controls. */
|
||||
title?: string;
|
||||
/** One-line caveat rendered under the header (e.g. "Changes apply …"). */
|
||||
caveat?: string;
|
||||
/** Load the current snapshot. */
|
||||
fetchSnapshot: () => Promise<ToolsSnapshot>;
|
||||
/** Persist an allowlist. ``null`` is an explicit "allow all" save. */
|
||||
saveAllowlist: (
|
||||
enabled: string[] | null,
|
||||
) => Promise<{ enabled_mcp_tools: string[] | null }>;
|
||||
/** Optional: drop any saved allowlist so the subject falls back to
|
||||
* its role-based default. Shows a "Reset to role default" button
|
||||
* when provided. */
|
||||
resetToRoleDefault?: () => Promise<{ enabled_mcp_tools: string[] | null }>;
|
||||
}
|
||||
|
||||
type TriState = "checked" | "unchecked" | "indeterminate";
|
||||
|
||||
function triStateForServer(
|
||||
toolNames: string[],
|
||||
allowed: Set<string> | null,
|
||||
): TriState {
|
||||
if (allowed === null) return "checked";
|
||||
if (toolNames.length === 0) return "unchecked";
|
||||
const enabledCount = toolNames.reduce(
|
||||
(n, name) => n + (allowed.has(name) ? 1 : 0),
|
||||
0,
|
||||
);
|
||||
if (enabledCount === 0) return "unchecked";
|
||||
if (enabledCount === toolNames.length) return "checked";
|
||||
return "indeterminate";
|
||||
}
|
||||
|
||||
function TriStateCheckbox({
|
||||
state,
|
||||
onChange,
|
||||
disabled,
|
||||
}: {
|
||||
state: TriState;
|
||||
onChange: (next: boolean) => void;
|
||||
disabled?: boolean;
|
||||
}) {
|
||||
const ref = useRef<HTMLInputElement>(null);
|
||||
useEffect(() => {
|
||||
if (ref.current) ref.current.indeterminate = state === "indeterminate";
|
||||
}, [state]);
|
||||
return (
|
||||
<input
|
||||
ref={ref}
|
||||
type="checkbox"
|
||||
checked={state === "checked"}
|
||||
disabled={disabled}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="h-3.5 w-3.5 rounded border-border/70 text-primary focus:ring-primary/40"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolRow({
|
||||
name,
|
||||
description,
|
||||
enabled,
|
||||
editable,
|
||||
onToggle,
|
||||
}: {
|
||||
name: string;
|
||||
description: string;
|
||||
enabled: boolean;
|
||||
editable: boolean;
|
||||
onToggle?: (next: boolean) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 py-1.5 px-2 rounded hover:bg-muted/30">
|
||||
{editable ? (
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
onChange={(e) => onToggle?.(e.target.checked)}
|
||||
className="mt-0.5 h-3.5 w-3.5 rounded border-border/70 text-primary focus:ring-primary/40"
|
||||
/>
|
||||
) : (
|
||||
<Lock className="mt-0.5 h-3 w-3 text-muted-foreground/60 flex-shrink-0" />
|
||||
)}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="text-xs font-medium text-foreground font-mono">
|
||||
{name}
|
||||
</div>
|
||||
{description && (
|
||||
<div className="text-[11px] text-muted-foreground leading-relaxed line-clamp-2">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function CollapsibleGroup({
|
||||
title,
|
||||
count,
|
||||
badge,
|
||||
expanded,
|
||||
onToggle,
|
||||
leading,
|
||||
children,
|
||||
}: {
|
||||
title: string;
|
||||
count: number;
|
||||
badge?: string;
|
||||
expanded: boolean;
|
||||
onToggle: () => void;
|
||||
leading?: React.ReactNode;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div className="mb-2 rounded-lg border border-border/40 bg-muted/10 overflow-hidden">
|
||||
<button
|
||||
onClick={onToggle}
|
||||
className="w-full flex items-center gap-2 px-2.5 py-1.5 text-left hover:bg-muted/30"
|
||||
>
|
||||
{expanded ? (
|
||||
<ChevronDown className="w-3.5 h-3.5 text-muted-foreground" />
|
||||
) : (
|
||||
<ChevronRight className="w-3.5 h-3.5 text-muted-foreground" />
|
||||
)}
|
||||
{leading}
|
||||
<span className="text-xs font-medium text-foreground flex-1 truncate">
|
||||
{title}
|
||||
</span>
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{badge ?? count}
|
||||
</span>
|
||||
</button>
|
||||
{expanded && (
|
||||
<div className="border-t border-border/30 px-1 py-1">{children}</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function ToolsEditor({
|
||||
subjectKey,
|
||||
title = "Tools",
|
||||
caveat,
|
||||
fetchSnapshot,
|
||||
saveAllowlist,
|
||||
resetToRoleDefault,
|
||||
}: ToolsEditorProps) {
|
||||
const [data, setData] = useState<ToolsSnapshot | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const [draftAllowed, setDraftAllowed] = useState<Set<string> | null>(null);
|
||||
const baselineRef = useRef<Set<string> | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const [savedRecently, setSavedRecently] = useState(false);
|
||||
|
||||
const [expanded, setExpanded] = useState<Record<string, boolean>>({});
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
fetchSnapshot()
|
||||
.then((d) => {
|
||||
if (cancelled) return;
|
||||
setData(d);
|
||||
const baseline =
|
||||
d.enabled_mcp_tools === null
|
||||
? null
|
||||
: new Set<string>(d.enabled_mcp_tools);
|
||||
baselineRef.current = baseline === null ? null : new Set(baseline);
|
||||
setDraftAllowed(baseline);
|
||||
})
|
||||
.catch((e) => {
|
||||
if (cancelled) return;
|
||||
setError((e as Error)?.message || "Failed to load tools");
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setLoading(false);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [subjectKey, fetchSnapshot]);
|
||||
|
||||
const allMcpNames = useMemo(() => {
|
||||
const s = new Set<string>();
|
||||
data?.mcp_servers.forEach((srv) => srv.tools.forEach((t) => s.add(t.name)));
|
||||
return s;
|
||||
}, [data]);
|
||||
|
||||
const dirty = useMemo(() => {
|
||||
const a = draftAllowed;
|
||||
const b = baselineRef.current;
|
||||
if (a === null && b === null) return false;
|
||||
if (a === null || b === null) return true;
|
||||
if (a.size !== b.size) return true;
|
||||
for (const n of a) if (!b.has(n)) return true;
|
||||
return false;
|
||||
}, [draftAllowed]);
|
||||
|
||||
const applyResult = (updated: string[] | null, isRoleDefault: boolean) => {
|
||||
baselineRef.current = updated === null ? null : new Set(updated);
|
||||
setDraftAllowed(updated === null ? null : new Set(updated));
|
||||
if (data) {
|
||||
const u = updated === null ? null : new Set(updated);
|
||||
setData({
|
||||
...data,
|
||||
enabled_mcp_tools: updated,
|
||||
is_role_default: isRoleDefault,
|
||||
mcp_servers: data.mcp_servers.map((srv) => ({
|
||||
...srv,
|
||||
tools: srv.tools.map((t) => ({
|
||||
...t,
|
||||
enabled: u === null ? true : u.has(t.name),
|
||||
})),
|
||||
})),
|
||||
});
|
||||
}
|
||||
setSavedRecently(true);
|
||||
setTimeout(() => setSavedRecently(false), 2500);
|
||||
};
|
||||
|
||||
const toggleOne = (name: string, next: boolean) => {
|
||||
setDraftAllowed((prev) => {
|
||||
const base =
|
||||
prev === null ? new Set<string>(allMcpNames) : new Set<string>(prev);
|
||||
if (next) base.add(name);
|
||||
else base.delete(name);
|
||||
return base;
|
||||
});
|
||||
};
|
||||
|
||||
const toggleServer = (serverNames: string[], next: boolean) => {
|
||||
setDraftAllowed((prev) => {
|
||||
const base =
|
||||
prev === null ? new Set<string>(allMcpNames) : new Set<string>(prev);
|
||||
if (next) serverNames.forEach((n) => base.add(n));
|
||||
else serverNames.forEach((n) => base.delete(n));
|
||||
return base;
|
||||
});
|
||||
};
|
||||
|
||||
const handleAllowAll = () => setDraftAllowed(null);
|
||||
|
||||
const handleCancel = () => {
|
||||
const baseline = baselineRef.current;
|
||||
setDraftAllowed(baseline === null ? null : new Set(baseline));
|
||||
setSaveError(null);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
// Only send tool names the server knows about (MCP tools).
|
||||
// The draft may contain lifecycle/synthetic names from the
|
||||
// baseline — strip those to avoid "Unknown MCP tool name" errors.
|
||||
const payload =
|
||||
draftAllowed === null
|
||||
? null
|
||||
: Array.from(draftAllowed)
|
||||
.filter((name) => allMcpNames.has(name))
|
||||
.sort();
|
||||
const result = await saveAllowlist(payload);
|
||||
applyResult(result.enabled_mcp_tools, false);
|
||||
} catch (e: unknown) {
|
||||
const err = e as { body?: { error?: string; unknown?: string[] } };
|
||||
const extra = err.body?.unknown
|
||||
? ` (${err.body.unknown.join(", ")})`
|
||||
: "";
|
||||
setSaveError((err.body?.error || "Save failed") + extra);
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetToRoleDefault = async () => {
|
||||
if (!resetToRoleDefault) return;
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
const result = await resetToRoleDefault();
|
||||
applyResult(result.enabled_mcp_tools, true);
|
||||
} catch (e: unknown) {
|
||||
const err = e as { body?: { error?: string } };
|
||||
setSaveError(err.body?.error || "Reset failed");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground py-3">
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
Loading tools…
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error || !data) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive py-3">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{error || "Could not load tools"}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const draftEnabledCount =
|
||||
draftAllowed === null ? allMcpNames.size : draftAllowed.size;
|
||||
const totalMcpCount = allMcpNames.size;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-1.5">
|
||||
<h4 className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider flex items-center gap-1.5">
|
||||
<Wrench className="w-3 h-3" /> {title}
|
||||
</h4>
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{draftEnabledCount}/{totalMcpCount} MCP enabled
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{caveat && (
|
||||
<div className="flex items-start gap-1.5 text-[11px] text-muted-foreground mb-2 px-2 py-1.5 rounded bg-muted/20 border border-border/40">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>{caveat}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{data.stale && (
|
||||
<div className="flex items-start gap-1.5 text-[11px] text-muted-foreground mb-3 px-2 py-1.5 rounded bg-muted/30">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>
|
||||
Catalog is unavailable. Start a session once to populate the tool list.
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(data.lifecycle.length > 0 || data.synthetic.length > 0) && (
|
||||
<CollapsibleGroup
|
||||
title="System tools (always enabled)"
|
||||
count={data.lifecycle.length + data.synthetic.length}
|
||||
expanded={!!expanded["__system"]}
|
||||
onToggle={() =>
|
||||
setExpanded((p) => ({ ...p, __system: !p["__system"] }))
|
||||
}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
{data.synthetic.map((t) => (
|
||||
<ToolRow
|
||||
key={`syn-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={true}
|
||||
editable={false}
|
||||
/>
|
||||
))}
|
||||
{data.lifecycle.map((t) => (
|
||||
<ToolRow
|
||||
key={`lc-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={true}
|
||||
editable={false}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</CollapsibleGroup>
|
||||
)}
|
||||
|
||||
{data.mcp_servers.map((srv) => {
|
||||
const toolNames = srv.tools.map((t) => t.name);
|
||||
const state = triStateForServer(toolNames, draftAllowed);
|
||||
const enabledInServer =
|
||||
draftAllowed === null
|
||||
? toolNames.length
|
||||
: toolNames.reduce(
|
||||
(n, name) => n + (draftAllowed.has(name) ? 1 : 0),
|
||||
0,
|
||||
);
|
||||
return (
|
||||
<CollapsibleGroup
|
||||
key={srv.name}
|
||||
title={srv.name === "(unknown)" ? "MCP Tools" : srv.name}
|
||||
count={srv.tools.length}
|
||||
badge={`${enabledInServer}/${srv.tools.length}`}
|
||||
expanded={!!expanded[srv.name]}
|
||||
onToggle={() =>
|
||||
setExpanded((p) => ({ ...p, [srv.name]: !p[srv.name] }))
|
||||
}
|
||||
leading={
|
||||
<TriStateCheckbox
|
||||
state={state}
|
||||
onChange={(next) => toggleServer(toolNames, next)}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
{srv.tools.map((t) => {
|
||||
const enabled =
|
||||
draftAllowed === null ? true : draftAllowed.has(t.name);
|
||||
return (
|
||||
<ToolRow
|
||||
key={`${srv.name}-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={enabled}
|
||||
editable={true}
|
||||
onToggle={(next) => toggleOne(t.name, next)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</CollapsibleGroup>
|
||||
);
|
||||
})}
|
||||
|
||||
<div className="flex items-center gap-2 pt-3 flex-wrap">
|
||||
{/* Primary actions */}
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!dirty || saving}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-medium hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{saving ? <Loader2 className="w-3 h-3 animate-spin" /> : <Check className="w-3 h-3" />}
|
||||
{saving ? "Saving…" : "Save"}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleCancel}
|
||||
disabled={!dirty || saving}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
|
||||
{/* Status */}
|
||||
{savedRecently && !dirty && (
|
||||
<span className="text-[11px] text-green-500 flex items-center gap-1">
|
||||
<Check className="w-3 h-3" /> Saved
|
||||
</span>
|
||||
)}
|
||||
{dirty && !saving && (
|
||||
<span className="text-[11px] text-amber-500">Unsaved changes</span>
|
||||
)}
|
||||
|
||||
{/* Quick actions */}
|
||||
<div className="ml-auto flex items-center gap-2">
|
||||
<button
|
||||
onClick={handleAllowAll}
|
||||
disabled={saving || draftAllowed === null}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Allow all
|
||||
</button>
|
||||
{resetToRoleDefault && (
|
||||
<button
|
||||
onClick={handleResetToRoleDefault}
|
||||
disabled={saving || !!data.is_role_default}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Reset to defaults
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{saveError && (
|
||||
<div className="flex items-start gap-1.5 mt-2 text-[11px] text-destructive">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>{saveError}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -140,13 +140,26 @@ export default function PromptLibrary() {
|
||||
promptsApi.list().then((r) => setCustomPrompts(r.prompts)).catch(() => {});
|
||||
}, []);
|
||||
|
||||
// Merge built-in + custom prompts
|
||||
const allPrompts = useMemo(() => [...customPrompts, ...prompts], [customPrompts]);
|
||||
// Filtered custom (my) prompts
|
||||
const filteredCustom = useMemo(() => {
|
||||
let result: (Prompt | CustomPrompt)[] = customPrompts;
|
||||
if (selectedCategory && selectedCategory !== "custom") {
|
||||
result = result.filter((p) => p.category === selectedCategory);
|
||||
}
|
||||
if (searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase();
|
||||
result = result.filter(
|
||||
(p) => p.title.toLowerCase().includes(query) || p.content.toLowerCase().includes(query),
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [customPrompts, searchQuery, selectedCategory]);
|
||||
|
||||
const filteredPrompts = useMemo(() => {
|
||||
let result = allPrompts;
|
||||
// Filtered built-in (community) prompts
|
||||
const filteredBuiltIn = useMemo(() => {
|
||||
let result: Prompt[] = prompts;
|
||||
if (selectedCategory === "custom") {
|
||||
result = result.filter((p) => "custom" in p && p.custom);
|
||||
result = [];
|
||||
} else if (selectedCategory) {
|
||||
result = result.filter((p) => p.category === selectedCategory);
|
||||
}
|
||||
@@ -157,13 +170,13 @@ export default function PromptLibrary() {
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [allPrompts, searchQuery, selectedCategory]);
|
||||
}, [searchQuery, selectedCategory]);
|
||||
|
||||
// Reset page when filters change
|
||||
useEffect(() => setPage(0), [searchQuery, selectedCategory]);
|
||||
|
||||
const totalPages = Math.max(1, Math.ceil(filteredPrompts.length / PAGE_SIZE));
|
||||
const pagedPrompts = filteredPrompts.slice(page * PAGE_SIZE, (page + 1) * PAGE_SIZE);
|
||||
const totalPages = Math.max(1, Math.ceil(filteredBuiltIn.length / PAGE_SIZE));
|
||||
const pagedBuiltIn = filteredBuiltIn.slice(page * PAGE_SIZE, (page + 1) * PAGE_SIZE);
|
||||
|
||||
const handleUsePrompt = (content: string, category: string) => {
|
||||
const queenId = categoryToQueen[category];
|
||||
@@ -196,7 +209,7 @@ export default function PromptLibrary() {
|
||||
Prompt Library
|
||||
</h2>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{allPrompts.length} prompts across {promptCategories.length + (customCount > 0 ? 1 : 0)} categories
|
||||
{customCount > 0 ? `${customCount} custom · ` : ""}{prompts.length} community prompts
|
||||
</span>
|
||||
</div>
|
||||
<button onClick={() => setAddModalOpen(true)}
|
||||
@@ -240,23 +253,47 @@ export default function PromptLibrary() {
|
||||
|
||||
{/* Prompts grid */}
|
||||
<div className="flex-1 overflow-y-auto p-6">
|
||||
{pagedPrompts.length > 0 ? (
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{pagedPrompts.map((prompt) => (
|
||||
<PromptCard
|
||||
key={typeof prompt.id === "string" ? prompt.id : `builtin-${prompt.id}`}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
onDelete={"custom" in prompt && prompt.custom ? () => handleDeletePrompt(prompt.id as string) : undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
{filteredCustom.length === 0 && pagedBuiltIn.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center h-full text-center">
|
||||
<Sparkles className="w-10 h-10 text-muted-foreground/30 mb-3" />
|
||||
<p className="text-sm text-muted-foreground">No prompts found</p>
|
||||
<p className="text-xs text-muted-foreground/60 mt-1">Try adjusting your search or category filter</p>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* My Prompts section */}
|
||||
{filteredCustom.length > 0 && (
|
||||
<div className="mb-8">
|
||||
<h3 className="text-xs font-semibold text-muted-foreground uppercase tracking-wider mb-3">My Prompts</h3>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{filteredCustom.map((prompt) => (
|
||||
<PromptCard
|
||||
key={prompt.id as string}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
onDelete={"custom" in prompt && prompt.custom ? () => handleDeletePrompt(prompt.id as string) : undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Community Prompts section */}
|
||||
{pagedBuiltIn.length > 0 && selectedCategory !== "custom" && (
|
||||
<div>
|
||||
<h3 className="text-xs font-semibold text-muted-foreground uppercase tracking-wider mb-3">Community Prompts</h3>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{pagedBuiltIn.map((prompt) => (
|
||||
<PromptCard
|
||||
key={`builtin-${prompt.id}`}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -264,7 +301,7 @@ export default function PromptLibrary() {
|
||||
{totalPages > 1 && (
|
||||
<div className="px-6 py-3 border-t border-border/60 flex items-center justify-between">
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{page * PAGE_SIZE + 1}–{Math.min((page + 1) * PAGE_SIZE, filteredPrompts.length)} of {filteredPrompts.length}
|
||||
{page * PAGE_SIZE + 1}–{Math.min((page + 1) * PAGE_SIZE, filteredBuiltIn.length)} of {filteredBuiltIn.length}
|
||||
</span>
|
||||
<div className="flex items-center gap-1">
|
||||
<button onClick={() => setPage((p) => Math.max(0, p - 1))} disabled={page === 0}
|
||||
|
||||
@@ -23,6 +23,32 @@ import { getQueenForAgent, slugToColonyId } from "@/lib/colony-registry";
|
||||
|
||||
const makeId = () => Math.random().toString(36).slice(2, 9);
|
||||
|
||||
// Remembers the last session the user had open in each queen DM so that
|
||||
// navigating away (e.g. to another queen) and back lands on the session
|
||||
// they were just in, instead of whichever session the server picks.
|
||||
const lastSessionKey = (queenId: string) => `hive:queen:${queenId}:lastSession`;
|
||||
const readLastSession = (queenId: string): string | null => {
|
||||
try {
|
||||
return localStorage.getItem(lastSessionKey(queenId));
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
const writeLastSession = (queenId: string, sessionId: string) => {
|
||||
try {
|
||||
localStorage.setItem(lastSessionKey(queenId), sessionId);
|
||||
} catch {
|
||||
/* storage disabled/full — best-effort */
|
||||
}
|
||||
};
|
||||
const clearLastSession = (queenId: string) => {
|
||||
try {
|
||||
localStorage.removeItem(lastSessionKey(queenId));
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
};
|
||||
|
||||
export default function QueenDM() {
|
||||
const { queenId } = useParams<{ queenId: string }>();
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
@@ -45,7 +71,17 @@ export default function QueenDM() {
|
||||
{ id: string; prompt: string; options?: string[] }[] | null
|
||||
>(null);
|
||||
const [awaitingInput, setAwaitingInput] = useState(false);
|
||||
const [tokenUsage, setTokenUsage] = useState({ input: 0, output: 0 });
|
||||
// `cached` and `cacheCreated` are subsets of `input` (providers count both
|
||||
// inside prompt_tokens already) — display them, never add them to a total.
|
||||
// `costUsd` is the session-total USD cost when the provider supplies one
|
||||
// (Anthropic, OpenAI, OpenRouter); 0 means unreported, not free.
|
||||
const [tokenUsage, setTokenUsage] = useState({
|
||||
input: 0,
|
||||
output: 0,
|
||||
cached: 0,
|
||||
cacheCreated: 0,
|
||||
costUsd: 0,
|
||||
});
|
||||
const [historySessions, setHistorySessions] = useState<HistorySession[]>([]);
|
||||
const [historyLoading, setHistoryLoading] = useState(false);
|
||||
const [switchingSessionId, setSwitchingSessionId] = useState<string | null>(
|
||||
@@ -92,7 +128,7 @@ export default function QueenDM() {
|
||||
setPendingQuestions(null);
|
||||
setAwaitingInput(false);
|
||||
setQueenPhase("independent");
|
||||
setTokenUsage({ input: 0, output: 0 });
|
||||
setTokenUsage({ input: 0, output: 0, cached: 0, cacheCreated: 0, costUsd: 0 });
|
||||
setInitialDraft(null);
|
||||
setColonySpawned(false);
|
||||
setSpawnedColonyName(null);
|
||||
@@ -160,6 +196,31 @@ export default function QueenDM() {
|
||||
);
|
||||
replayStateRef.current = replayState;
|
||||
|
||||
// Sum historical llm_turn_complete events so Tokens/Cost carry over
|
||||
// across resume. SSE does not replay llm_turn_complete (see
|
||||
// routes_events.py _REPLAY_TYPES), so no double-count risk — live
|
||||
// SSE deltas that may have already landed are kept via functional
|
||||
// merge below.
|
||||
const seed = { input: 0, output: 0, cached: 0, cacheCreated: 0, costUsd: 0 };
|
||||
for (const evt of events) {
|
||||
if (evt.type !== "llm_turn_complete" || !evt.data) continue;
|
||||
const d = evt.data as Record<string, unknown>;
|
||||
seed.input += (d.input_tokens as number) || 0;
|
||||
seed.output += (d.output_tokens as number) || 0;
|
||||
seed.cached += (d.cached_tokens as number) || 0;
|
||||
seed.cacheCreated += (d.cache_creation_tokens as number) || 0;
|
||||
seed.costUsd += (d.cost_usd as number) || 0;
|
||||
}
|
||||
if (!cancelled()) {
|
||||
setTokenUsage((prev) => ({
|
||||
input: prev.input + seed.input,
|
||||
output: prev.output + seed.output,
|
||||
cached: prev.cached + seed.cached,
|
||||
cacheCreated: prev.cacheCreated + seed.cacheCreated,
|
||||
costUsd: prev.costUsd + seed.costUsd,
|
||||
}));
|
||||
}
|
||||
|
||||
// Show a banner if the server truncated older events.
|
||||
const droppedCount = Math.max(0, total - returned);
|
||||
if (truncated && droppedCount > 0) {
|
||||
@@ -199,6 +260,19 @@ export default function QueenDM() {
|
||||
useEffect(() => {
|
||||
if (!queenId) return;
|
||||
|
||||
// If we arrived without an explicit session in the URL and aren't
|
||||
// bootstrapping a new one, redirect to the last session the user had
|
||||
// open for this queen. Session IDs are always of the form
|
||||
// "session_<timestamp>_<hex>", so we gate on that prefix to avoid
|
||||
// redirecting to anything unexpected that landed in storage.
|
||||
if (!selectedSessionParam && newSessionFlag !== "1") {
|
||||
const stored = readLastSession(queenId);
|
||||
if (stored && stored.startsWith("session_")) {
|
||||
setSearchParams({ session: stored }, { replace: true });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
resetViewState();
|
||||
setLoading(true);
|
||||
|
||||
@@ -314,7 +388,17 @@ export default function QueenDM() {
|
||||
await restoreMessages(sid, () => cancelled);
|
||||
refresh();
|
||||
} catch {
|
||||
// Session creation failed
|
||||
// Session creation/selection failed. If the URL param came from
|
||||
// our own localStorage restore, the stored session is stale (e.g.
|
||||
// deleted on disk) — clear it so the next navigation falls
|
||||
// through to getOrCreate instead of looping on the bad id.
|
||||
if (
|
||||
queenId &&
|
||||
selectedSessionParam &&
|
||||
selectedSessionParam === readLastSession(queenId)
|
||||
) {
|
||||
clearLastSession(queenId);
|
||||
}
|
||||
} finally {
|
||||
if (!cancelled) {
|
||||
setLoading(false);
|
||||
@@ -337,6 +421,13 @@ export default function QueenDM() {
|
||||
setSearchParams,
|
||||
]);
|
||||
|
||||
// Remember the session the user is currently viewing so switching queens
|
||||
// and coming back lands on it instead of whatever the server picks.
|
||||
useEffect(() => {
|
||||
if (!queenId || !sessionId) return;
|
||||
writeLastSession(queenId, sessionId);
|
||||
}, [queenId, sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!queenId) return;
|
||||
let cancelled = false;
|
||||
@@ -520,7 +611,18 @@ export default function QueenDM() {
|
||||
if (event.data) {
|
||||
const inp = (event.data.input_tokens as number) || 0;
|
||||
const out = (event.data.output_tokens as number) || 0;
|
||||
setTokenUsage((prev) => ({ input: prev.input + inp, output: prev.output + out }));
|
||||
// cached / cache_creation are subsets of input — accumulate
|
||||
// separately for display, do NOT roll into input/total.
|
||||
const cached = (event.data.cached_tokens as number) || 0;
|
||||
const cacheCreated = (event.data.cache_creation_tokens as number) || 0;
|
||||
const costUsd = (event.data.cost_usd as number) || 0;
|
||||
setTokenUsage((prev) => ({
|
||||
input: prev.input + inp,
|
||||
output: prev.output + out,
|
||||
cached: prev.cached + cached,
|
||||
cacheCreated: prev.cacheCreated + cacheCreated,
|
||||
costUsd: prev.costUsd + costUsd,
|
||||
}));
|
||||
}
|
||||
// Flush one queued message per LLM turn boundary. This is the
|
||||
// real "turn ended" signal in a queen DM — execution_completed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,283 @@
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { Wrench, Crown, Network, Server, Loader2, AlertCircle } from "lucide-react";
|
||||
import { queensApi } from "@/api/queens";
|
||||
import { coloniesApi, type ColonySummary } from "@/api/colonies";
|
||||
import { slugToDisplayName, sortQueenProfiles } from "@/lib/colony-registry";
|
||||
import QueenToolsSection from "@/components/QueenToolsSection";
|
||||
import ColonyToolsSection from "@/components/ColonyToolsSection";
|
||||
import McpServersPanel from "@/components/McpServersPanel";
|
||||
|
||||
type Tab = "queens" | "colonies" | "mcp";
|
||||
|
||||
export default function ToolLibrary() {
|
||||
const [tab, setTab] = useState<Tab>("queens");
|
||||
|
||||
return (
|
||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||
{/* Header */}
|
||||
<div className="px-6 py-4 border-b border-border/60">
|
||||
<div className="flex items-baseline gap-3 mb-3">
|
||||
<h2 className="text-lg font-semibold text-foreground flex items-center gap-2">
|
||||
<Wrench className="w-5 h-5 text-primary" />
|
||||
Tool Configuration
|
||||
</h2>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Curate which tools each queen and colony can call, and register your own MCP servers.
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<TabButton active={tab === "queens"} onClick={() => setTab("queens")} icon={<Crown className="w-3.5 h-3.5" />}>
|
||||
Queens
|
||||
</TabButton>
|
||||
<TabButton active={tab === "colonies"} onClick={() => setTab("colonies")} icon={<Network className="w-3.5 h-3.5" />}>
|
||||
Colonies
|
||||
</TabButton>
|
||||
<TabButton active={tab === "mcp"} onClick={() => setTab("mcp")} icon={<Server className="w-3.5 h-3.5" />}>
|
||||
MCP Servers
|
||||
</TabButton>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
{tab === "queens" && <QueensTab />}
|
||||
{tab === "colonies" && <ColoniesTab />}
|
||||
{tab === "mcp" && (
|
||||
<div className="px-6 py-6 max-w-4xl">
|
||||
<McpServersPanel />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function TabButton({
|
||||
active,
|
||||
onClick,
|
||||
icon,
|
||||
children,
|
||||
}: {
|
||||
active: boolean;
|
||||
onClick: () => void;
|
||||
icon: React.ReactNode;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-md text-sm font-medium ${
|
||||
active
|
||||
? "bg-primary/15 text-primary"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
{icon}
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Queens tab ---------------------------------------------------------
|
||||
|
||||
function QueensTab() {
|
||||
const [queens, setQueens] = useState<Array<{ id: string; name: string; title: string }> | null>(null);
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
queensApi
|
||||
.list()
|
||||
.then((r) => {
|
||||
const sorted = sortQueenProfiles(r.queens);
|
||||
setQueens(sorted);
|
||||
if (sorted.length > 0) setSelected((prev) => prev ?? sorted[0].id);
|
||||
})
|
||||
.catch((e: Error) => setError(e.message || "Failed to load queens"));
|
||||
}, []);
|
||||
|
||||
if (error) return <ErrorBlock message={error} />;
|
||||
if (queens === null) return <LoadingBlock label="Loading queens…" />;
|
||||
if (queens.length === 0)
|
||||
return <EmptyBlock label="No queens yet. Create one to curate its tools." />;
|
||||
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<SidePicker>
|
||||
{queens.map((q) => (
|
||||
<PickerItem
|
||||
key={q.id}
|
||||
active={selected === q.id}
|
||||
onClick={() => setSelected(q.id)}
|
||||
primary={q.name}
|
||||
secondary={q.title}
|
||||
/>
|
||||
))}
|
||||
</SidePicker>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-5 min-w-0">
|
||||
{selected ? (
|
||||
<>
|
||||
{(() => {
|
||||
const queen = queens.find((q) => q.id === selected);
|
||||
return queen ? (
|
||||
<div className="mb-4 pb-3 border-b border-border/40">
|
||||
<h3 className="text-base font-semibold text-foreground">
|
||||
{queen.name}
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
{queen.title}
|
||||
</p>
|
||||
</div>
|
||||
) : null;
|
||||
})()}
|
||||
<QueenToolsSection queenId={selected} />
|
||||
</>
|
||||
) : (
|
||||
<EmptyBlock label="Pick a queen to edit her tool allowlist." />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Colonies tab -------------------------------------------------------
|
||||
|
||||
function ColoniesTab() {
|
||||
const [colonies, setColonies] = useState<ColonySummary[] | null>(null);
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
coloniesApi
|
||||
.list()
|
||||
.then((r) => {
|
||||
setColonies(r.colonies);
|
||||
if (r.colonies.length > 0)
|
||||
setSelected((prev) => prev ?? r.colonies[0].name);
|
||||
})
|
||||
.catch((e: Error) => setError(e.message || "Failed to load colonies"));
|
||||
}, []);
|
||||
|
||||
const sorted = useMemo(() => {
|
||||
if (!colonies) return null;
|
||||
return [...colonies].sort((a, b) => a.name.localeCompare(b.name));
|
||||
}, [colonies]);
|
||||
|
||||
if (error) return <ErrorBlock message={error} />;
|
||||
if (sorted === null) return <LoadingBlock label="Loading colonies…" />;
|
||||
if (sorted.length === 0)
|
||||
return (
|
||||
<EmptyBlock label="No colonies yet. Ask a queen to incubate one and its tools will show up here." />
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<SidePicker>
|
||||
{sorted.map((c) => (
|
||||
<PickerItem
|
||||
key={c.name}
|
||||
active={selected === c.name}
|
||||
onClick={() => setSelected(c.name)}
|
||||
primary={slugToDisplayName(c.name)}
|
||||
secondary={
|
||||
c.has_allowlist
|
||||
? `${c.enabled_count ?? 0} tools allowed · ${c.queen_name ?? ""}`
|
||||
: `all tools · ${c.queen_name ?? ""}`
|
||||
}
|
||||
tertiary={c.name}
|
||||
/>
|
||||
))}
|
||||
</SidePicker>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-5 min-w-0">
|
||||
{selected ? (
|
||||
<>
|
||||
<div className="mb-4 pb-3 border-b border-border/40">
|
||||
<h3 className="text-base font-semibold text-foreground">
|
||||
{slugToDisplayName(selected)}
|
||||
</h3>
|
||||
<p className="text-[11px] text-muted-foreground font-mono mt-0.5">
|
||||
{selected}
|
||||
</p>
|
||||
</div>
|
||||
<ColonyToolsSection colonyName={selected} />
|
||||
</>
|
||||
) : (
|
||||
<EmptyBlock label="Pick a colony to edit its tool allowlist." />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Shared primitives --------------------------------------------------
|
||||
|
||||
function SidePicker({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<div className="w-[260px] flex-shrink-0 border-r border-border/60 overflow-y-auto py-3 px-2 flex flex-col gap-1">
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function PickerItem({
|
||||
active,
|
||||
onClick,
|
||||
primary,
|
||||
secondary,
|
||||
tertiary,
|
||||
}: {
|
||||
active: boolean;
|
||||
onClick: () => void;
|
||||
primary: string;
|
||||
secondary?: string;
|
||||
tertiary?: string;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`text-left px-3 py-2 rounded-md text-sm ${
|
||||
active
|
||||
? "bg-primary/15 text-primary"
|
||||
: "text-foreground hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
<div className="font-medium truncate">{primary}</div>
|
||||
{secondary && (
|
||||
<div className="text-[11px] text-muted-foreground truncate">
|
||||
{secondary}
|
||||
</div>
|
||||
)}
|
||||
{tertiary && (
|
||||
<div className="text-[10px] text-muted-foreground/60 font-mono truncate">
|
||||
{tertiary}
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
function LoadingBlock({ label }: { label: string }) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground px-6 py-6">
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
{label}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EmptyBlock({ label }: { label: string }) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-muted-foreground px-6 py-6">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5" />
|
||||
<span>{label}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ErrorBlock({ message }: { message: string }) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive px-6 py-6">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{message}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -42,6 +42,7 @@ _HIVE_PATH_CONSUMERS = (
|
||||
"framework.server.session_manager",
|
||||
"framework.server.queen_orchestrator",
|
||||
"framework.server.routes_queens",
|
||||
"framework.server.routes_skills",
|
||||
"framework.server.app",
|
||||
"framework.agents.discovery",
|
||||
"framework.agents.queen.queen_profiles",
|
||||
|
||||
@@ -9,26 +9,25 @@ from framework.llm.provider import Tool
|
||||
|
||||
|
||||
class TestSupportsImageToolResults:
|
||||
"""Verify the deny-list correctly identifies models that can't handle images."""
|
||||
"""Verify catalog-driven vision capability checks."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"openai/gpt-4o",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
# Catalog entries with supports_vision=true
|
||||
"claude-haiku-4-5-20251001",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"google/gemini-1.5-flash",
|
||||
"mistral/mistral-large",
|
||||
"groq/llama3-70b",
|
||||
"together/meta-llama/Llama-3-70b",
|
||||
"fireworks_ai/llama-v3-70b",
|
||||
"azure/gpt-4o",
|
||||
"kimi/claude-sonnet-4-20250514",
|
||||
"hive/claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-6",
|
||||
"gpt-5.4",
|
||||
"gpt-5.4-mini",
|
||||
"gemini-3-flash-preview",
|
||||
"kimi-k2.5",
|
||||
# Provider-prefixed catalog entries
|
||||
"openrouter/openai/gpt-5.4",
|
||||
"openrouter/anthropic/claude-sonnet-4.6",
|
||||
# Unknown models default to True (hosted frontier assumption)
|
||||
"some-future-model",
|
||||
"azure/gpt-5",
|
||||
],
|
||||
)
|
||||
def test_supported_models(self, model: str):
|
||||
@@ -37,27 +36,24 @@ class TestSupportsImageToolResults:
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"deepseek/deepseek-chat",
|
||||
"deepseek/deepseek-coder",
|
||||
"deepseek-chat",
|
||||
# Catalog entries with supports_vision=false
|
||||
"deepseek-reasoner",
|
||||
"ollama/llama3",
|
||||
"ollama/mistral",
|
||||
"ollama_chat/llama3",
|
||||
"lm_studio/my-model",
|
||||
"vllm/meta-llama/Llama-3-70b",
|
||||
"llamacpp/model",
|
||||
"cerebras/llama3-70b",
|
||||
"deepseek-v4-pro",
|
||||
"deepseek-v4-flash",
|
||||
"glm-5.1",
|
||||
"queen",
|
||||
"MiniMax-M2.7",
|
||||
"codestral-2508",
|
||||
"llama-3.3-70b-versatile",
|
||||
# Provider-prefixed forms resolve to the same catalog entry
|
||||
"deepseek/deepseek-reasoner",
|
||||
"hive/glm-5.1",
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
],
|
||||
)
|
||||
def test_unsupported_models(self, model: str):
|
||||
assert supports_image_tool_results(model) is False
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert supports_image_tool_results("DeepSeek/deepseek-chat") is False
|
||||
assert supports_image_tool_results("OLLAMA/llama3") is False
|
||||
assert supports_image_tool_results("GPT-4o") is True
|
||||
|
||||
|
||||
class TestFilterToolsForModel:
|
||||
"""Verify ``filter_tools_for_model`` — the real helper used by AgentLoop."""
|
||||
@@ -68,7 +64,7 @@ class TestFilterToolsForModel:
|
||||
Tool(name="browser_screenshot", description="take a screenshot", produces_image=True),
|
||||
Tool(name="browser_snapshot", description="get page content"),
|
||||
]
|
||||
filtered, hidden = filter_tools_for_model(tools, "glm-5")
|
||||
filtered, hidden = filter_tools_for_model(tools, "glm-5.1")
|
||||
names = [t.name for t in filtered]
|
||||
assert "browser_screenshot" not in names
|
||||
assert "read_file" in names
|
||||
@@ -80,7 +76,7 @@ class TestFilterToolsForModel:
|
||||
Tool(name="read_file", description="read a file"),
|
||||
Tool(name="browser_screenshot", description="take a screenshot", produces_image=True),
|
||||
]
|
||||
filtered, hidden = filter_tools_for_model(tools, "claude-sonnet-4-20250514")
|
||||
filtered, hidden = filter_tools_for_model(tools, "claude-sonnet-4-5-20250929")
|
||||
assert {t.name for t in filtered} == {"read_file", "browser_screenshot"}
|
||||
assert hidden == []
|
||||
|
||||
@@ -90,8 +86,8 @@ class TestFilterToolsForModel:
|
||||
Tool(name="read_file", description="read a file"),
|
||||
Tool(name="web_search", description="search the web"),
|
||||
]
|
||||
text_only, text_hidden = filter_tools_for_model(tools, "glm-5")
|
||||
vision, vision_hidden = filter_tools_for_model(tools, "gpt-4o")
|
||||
text_only, text_hidden = filter_tools_for_model(tools, "glm-5.1")
|
||||
vision, vision_hidden = filter_tools_for_model(tools, "claude-sonnet-4-5-20250929")
|
||||
assert len(text_only) == 2 and text_hidden == []
|
||||
assert len(vision) == 2 and vision_hidden == []
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ needs and run everything against a temp directory.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
@@ -278,6 +279,17 @@ async def test_colony_spawn_creates_correct_artifacts(tmp_path, monkeypatch):
|
||||
assert resp.status == 200, await resp.text()
|
||||
body = await resp.json()
|
||||
|
||||
# fork_session_into_colony schedules the compaction + worker-storage
|
||||
# copy onto _BACKGROUND_FORK_TASKS and returns. In prod the colony-
|
||||
# open path blocks on compaction_status.await_completion; the test
|
||||
# skips that step, so drain the bg tasks here before asserting on
|
||||
# the artifacts they produce (otherwise the worker-storage check is
|
||||
# a race that flakes under CI load).
|
||||
from framework.server.routes_execution import _BACKGROUND_FORK_TASKS
|
||||
|
||||
if _BACKGROUND_FORK_TASKS:
|
||||
await asyncio.gather(*list(_BACKGROUND_FORK_TASKS), return_exceptions=True)
|
||||
|
||||
colony_session_id = body["queen_session_id"]
|
||||
assert body["colony_name"] == "honeycomb"
|
||||
assert body["is_new"] is True
|
||||
|
||||
@@ -63,6 +63,7 @@ class MockStreamingLLM(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
**kwargs,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
|
||||
@@ -311,7 +312,7 @@ class TestReportToParent:
|
||||
model: str = "mock"
|
||||
stream_calls: list[dict] = []
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
self.stream_calls.append({"messages": messages})
|
||||
raise RuntimeError("boom — simulated LLM crash")
|
||||
yield # pragma: no cover — make this an async generator
|
||||
@@ -485,8 +486,12 @@ class TestReportToParentGatingByStream:
|
||||
try:
|
||||
# Spawn a parallel worker — its tool list should include report_to_parent
|
||||
await colony.spawn(task="test", count=1)
|
||||
# After the worker's first LLM call, check the recorded tools
|
||||
await asyncio.sleep(0.2) # let the background task run
|
||||
# Poll until the worker fires its first LLM call. Bare sleeps were
|
||||
# flaky on slow Windows CI; loop with a generous deadline instead.
|
||||
for _ in range(100):
|
||||
if llm.stream_calls:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
assert llm.stream_calls, "Worker never called the LLM"
|
||||
worker_tools = llm.stream_calls[0]["tools"]
|
||||
tool_names = [t.name for t in (worker_tools or [])]
|
||||
|
||||
@@ -38,7 +38,7 @@ from framework.tools.queen_lifecycle_tools import register_queen_lifecycle_tools
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, sid: str = "session_test_create_colony"):
|
||||
def __init__(self, sid: str = "session_test_create_colony", queen_name: str = "sophia"):
|
||||
self.id = sid
|
||||
self.colony = None
|
||||
self.colony_runtime = None
|
||||
@@ -46,6 +46,7 @@ class _FakeSession:
|
||||
self.worker_path = None
|
||||
self.available_triggers: dict = {}
|
||||
self.active_trigger_ids: set = set()
|
||||
self.queen_name = queen_name
|
||||
|
||||
|
||||
def _make_executor():
|
||||
@@ -161,6 +162,60 @@ async def test_happy_path_emits_colony_created_event(patched_home: Path, patched
|
||||
assert ev.data.get("is_new") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_colony_inherits_queen_override_state(patched_home: Path, patched_fork: list[dict]) -> None:
|
||||
"""Seed the colony's skills_overrides.json from the queen's at fork
|
||||
time. A queen who enabled a preset (e.g. hive.x-automation) before
|
||||
calling create_colony must produce a colony that also has it
|
||||
enabled — without needing a second UI toggle on the colony page.
|
||||
"""
|
||||
from framework.config import QUEENS_DIR
|
||||
from framework.skills.overrides import (
|
||||
OverrideEntry,
|
||||
Provenance,
|
||||
SkillOverrideStore,
|
||||
)
|
||||
|
||||
# Pre-seed the queen's override file.
|
||||
queen_home = QUEENS_DIR / "sophia"
|
||||
queen_home.mkdir(parents=True, exist_ok=True)
|
||||
qstore = SkillOverrideStore.load(queen_home / "skills_overrides.json")
|
||||
qstore.upsert(
|
||||
"hive.x-automation",
|
||||
OverrideEntry(enabled=True, provenance=Provenance.PRESET),
|
||||
)
|
||||
qstore.upsert(
|
||||
"hive.note-taking",
|
||||
OverrideEntry(enabled=False, provenance=Provenance.FRAMEWORK),
|
||||
)
|
||||
qstore.save()
|
||||
|
||||
executor, _ = _make_executor()
|
||||
payload = await _call(
|
||||
executor,
|
||||
colony_name="inheritance_check",
|
||||
task="t",
|
||||
skill_name="bespoke-skill",
|
||||
skill_description="Written during this create_colony call.",
|
||||
skill_body=_DEFAULT_BODY,
|
||||
)
|
||||
assert payload.get("status") == "created", f"Tool error: {payload}"
|
||||
|
||||
colony_overrides = patched_home / ".hive" / "colonies" / "inheritance_check" / "skills_overrides.json"
|
||||
cstore = SkillOverrideStore.load(colony_overrides)
|
||||
|
||||
# Inherited entries from the queen:
|
||||
assert cstore.get("hive.x-automation").enabled is True
|
||||
assert cstore.get("hive.note-taking").enabled is False
|
||||
|
||||
# Newly-written skill is also registered with queen_created provenance:
|
||||
bespoke = cstore.get("bespoke-skill")
|
||||
assert bespoke is not None
|
||||
assert bespoke.provenance == Provenance.QUEEN_CREATED
|
||||
assert bespoke.enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_materializes_skill_under_colony_dir(patched_home: Path, patched_fork: list[dict]) -> None:
|
||||
"""Inline skill content is written to ~/.hive/colonies/{colony}/.hive/skills/{name}/."""
|
||||
@@ -204,6 +259,20 @@ async def test_happy_path_materializes_skill_under_colony_dir(patched_home: Path
|
||||
assert f"description: {description}" in text
|
||||
assert "HoneyComb API Operational Protocol" in text
|
||||
|
||||
# create_colony should also register the skill in the colony's
|
||||
# override store with ``queen_created`` provenance so the UI can
|
||||
# display it as queen-authored + editable.
|
||||
from framework.skills.overrides import Provenance, SkillOverrideStore
|
||||
|
||||
overrides_path = patched_home / ".hive" / "colonies" / "honeycomb_research" / "skills_overrides.json"
|
||||
assert overrides_path.exists(), "create_colony should write a skills_overrides.json ledger"
|
||||
store = SkillOverrideStore.load(overrides_path)
|
||||
entry = store.get("honeycomb-api-protocol")
|
||||
assert entry is not None
|
||||
assert entry.provenance == Provenance.QUEEN_CREATED
|
||||
assert entry.enabled is True
|
||||
assert (entry.created_by or "").startswith("queen:")
|
||||
|
||||
# Critically: the skill must NOT land in the shared user-scope dir —
|
||||
# that was the leak we are fixing.
|
||||
assert not (patched_home / ".hive" / "skills" / "honeycomb-api-protocol").exists()
|
||||
|
||||
@@ -812,6 +812,9 @@ class TestConveniencePublishers:
|
||||
model="claude-sonnet-4-20250514",
|
||||
input_tokens=100,
|
||||
output_tokens=50,
|
||||
cached_tokens=30,
|
||||
cache_creation_tokens=10,
|
||||
cost_usd=0.0042,
|
||||
execution_id="exec_1",
|
||||
iteration=3,
|
||||
)
|
||||
@@ -822,6 +825,11 @@ class TestConveniencePublishers:
|
||||
assert received[0].data["model"] == "claude-sonnet-4-20250514"
|
||||
assert received[0].data["input_tokens"] == 100
|
||||
assert received[0].data["output_tokens"] == 50
|
||||
# cached / cache_creation are subsets of input — propagated for
|
||||
# display, NOT additive to input_tokens.
|
||||
assert received[0].data["cached_tokens"] == 30
|
||||
assert received[0].data["cache_creation_tokens"] == 10
|
||||
assert received[0].data["cost_usd"] == 0.0042
|
||||
assert received[0].data["iteration"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -53,6 +53,7 @@ class MockStreamingLLM(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
**kwargs,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
if not self.scenarios:
|
||||
@@ -1079,7 +1080,7 @@ class ErrorThenSuccessLLM(LLMProvider):
|
||||
self.success_scenario = success_scenario
|
||||
self._call_index = 0
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
call_num = self._call_index
|
||||
self._call_index += 1
|
||||
if call_num < self.fail_count:
|
||||
@@ -1201,7 +1202,7 @@ class TestTransientErrorRetry:
|
||||
class StreamErrorThenSuccessLLM(LLMProvider):
|
||||
model: str = "mock"
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
nonlocal call_index
|
||||
idx = call_index
|
||||
call_index += 1
|
||||
@@ -1390,7 +1391,7 @@ class ToolRepeatLLM(LLMProvider):
|
||||
self.final_text = final_text
|
||||
self._call_index = 0
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
idx = self._call_index
|
||||
self._call_index += 1
|
||||
# Which outer iteration we're in (2 calls per iteration)
|
||||
@@ -1999,7 +2000,7 @@ class TestToolConcurrencyPartition:
|
||||
def __init__(self):
|
||||
self._calls = 0
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
self._calls += 1
|
||||
if self._calls == 1:
|
||||
# Emit the tool call, stall, then finish.
|
||||
|
||||
@@ -23,9 +23,14 @@ from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import (
|
||||
OPENROUTER_TOOL_COMPAT_MODEL_CACHE,
|
||||
LiteLLMProvider,
|
||||
_build_system_message,
|
||||
_compute_retry_delay,
|
||||
_cost_from_tokens,
|
||||
_ensure_ollama_chat_prefix,
|
||||
_extract_cache_tokens,
|
||||
_extract_cost,
|
||||
_is_ollama_model,
|
||||
_model_supports_cache_control,
|
||||
_summarize_request_for_log,
|
||||
)
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
@@ -1192,3 +1197,446 @@ class TestGetLlmExtraKwargsOllama:
|
||||
with patch("framework.config.get_hive_config", return_value={}):
|
||||
result = get_llm_extra_kwargs()
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestModelSupportsCacheControl:
|
||||
"""`cache_control` allowlist covers native providers AND OpenRouter sub-providers
|
||||
whose upstream API honors the marker (Anthropic, Gemini, GLM, MiniMax).
|
||||
Auto-cache sub-providers (OpenAI, DeepSeek, Grok, Moonshot, Groq) are
|
||||
intentionally excluded: sending cache_control is a no-op and a false win."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"anthropic/claude-opus-4-5",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"minimax/minimax-text-01",
|
||||
"MiniMax-Text-01",
|
||||
"zai-glm-4.6",
|
||||
"glm-4.6",
|
||||
"openrouter/anthropic/claude-opus-4.5",
|
||||
"openrouter/anthropic/claude-sonnet-4.5",
|
||||
"openrouter/google/gemini-2.5-pro",
|
||||
"openrouter/google/gemini-2.5-flash",
|
||||
"openrouter/z-ai/glm-5.1",
|
||||
"openrouter/z-ai/glm-4.6",
|
||||
"openrouter/minimax/minimax-text-01",
|
||||
],
|
||||
)
|
||||
def test_supported(self, model):
|
||||
assert _model_supports_cache_control(model) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o-mini",
|
||||
"gemini/gemini-1.5-flash",
|
||||
"ollama_chat/llama3",
|
||||
"openrouter/openai/gpt-4o",
|
||||
"openrouter/deepseek/deepseek-chat",
|
||||
"openrouter/x-ai/grok-2",
|
||||
"openrouter/moonshotai/kimi-k2",
|
||||
"openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
],
|
||||
)
|
||||
def test_unsupported(self, model):
|
||||
assert _model_supports_cache_control(model) is False
|
||||
|
||||
|
||||
class TestBuildSystemMessageOpenRouter:
|
||||
"""`_build_system_message` should split static/dynamic blocks whenever
|
||||
the model — native OR OpenRouter-routed — supports cache_control."""
|
||||
|
||||
def test_openrouter_anthropic_splits_into_two_blocks(self):
|
||||
msg = _build_system_message(
|
||||
system="static prefix",
|
||||
system_dynamic_suffix="dynamic tail",
|
||||
model="openrouter/anthropic/claude-opus-4.5",
|
||||
)
|
||||
assert msg == {
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "static prefix",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
{"type": "text", "text": "dynamic tail"},
|
||||
],
|
||||
}
|
||||
|
||||
def test_openrouter_gemini_splits_into_two_blocks(self):
|
||||
msg = _build_system_message(
|
||||
system="static prefix",
|
||||
system_dynamic_suffix="dynamic tail",
|
||||
model="openrouter/google/gemini-2.5-pro",
|
||||
)
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
assert msg["content"][1] == {"type": "text", "text": "dynamic tail"}
|
||||
|
||||
def test_openrouter_glm_splits_into_two_blocks(self):
|
||||
msg = _build_system_message(
|
||||
system="static prefix",
|
||||
system_dynamic_suffix="dynamic tail",
|
||||
model="openrouter/z-ai/glm-5.1",
|
||||
)
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
def test_openrouter_openai_stays_concatenated(self):
|
||||
"""OpenAI via OpenRouter auto-caches; sending cache_control is a no-op."""
|
||||
msg = _build_system_message(
|
||||
system="static prefix",
|
||||
system_dynamic_suffix="dynamic tail",
|
||||
model="openrouter/openai/gpt-4o",
|
||||
)
|
||||
assert msg == {
|
||||
"role": "system",
|
||||
"content": "static prefix\n\ndynamic tail",
|
||||
}
|
||||
|
||||
def test_no_suffix_anthropic_gets_top_level_cache_control(self):
|
||||
msg = _build_system_message(
|
||||
system="static prefix",
|
||||
system_dynamic_suffix=None,
|
||||
model="openrouter/anthropic/claude-opus-4.5",
|
||||
)
|
||||
assert msg == {
|
||||
"role": "system",
|
||||
"content": "static prefix",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
|
||||
|
||||
class TestOpenRouterToolCompatCacheControl:
|
||||
"""Tool-compat path must pass cache_control through when the routed
|
||||
sub-provider honors it. Before this, the queen persona+tool-list prefix
|
||||
was recomputed every turn on Anthropic/GLM via OpenRouter."""
|
||||
|
||||
def test_tool_compat_messages_split_for_cache_capable_model(self):
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/anthropic/claude-opus-4.5",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
||||
)
|
||||
]
|
||||
|
||||
full_messages = provider._build_openrouter_tool_compat_messages(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
system="You are a queen.",
|
||||
tools=tools,
|
||||
system_dynamic_suffix="Current time: 2026-04-23T00:00:00Z",
|
||||
)
|
||||
|
||||
system_msg = full_messages[0]
|
||||
assert system_msg["role"] == "system"
|
||||
assert isinstance(system_msg["content"], list)
|
||||
assert len(system_msg["content"]) == 2
|
||||
|
||||
static_block = system_msg["content"][0]
|
||||
assert static_block["cache_control"] == {"type": "ephemeral"}
|
||||
assert "You are a queen." in static_block["text"]
|
||||
assert "Tool compatibility mode is active" in static_block["text"]
|
||||
assert "web_search" in static_block["text"]
|
||||
assert "2026-04-23" not in static_block["text"]
|
||||
|
||||
dynamic_block = system_msg["content"][1]
|
||||
assert "cache_control" not in dynamic_block
|
||||
assert dynamic_block["text"] == "Current time: 2026-04-23T00:00:00Z"
|
||||
|
||||
def test_tool_compat_messages_stay_concatenated_for_liquid(self):
|
||||
"""Liquid (and other non-cache-control OR sub-providers) keep legacy behavior."""
|
||||
provider = LiteLLMProvider(
|
||||
model="openrouter/liquid/lfm-2.5-1.2b-thinking:free",
|
||||
api_key="test-key",
|
||||
)
|
||||
tools = [
|
||||
Tool(
|
||||
name="web_search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"query": {"type": "string"}}, "required": ["query"]},
|
||||
)
|
||||
]
|
||||
|
||||
full_messages = provider._build_openrouter_tool_compat_messages(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
system="You are a queen.",
|
||||
tools=tools,
|
||||
system_dynamic_suffix="Current time: 2026-04-23T00:00:00Z",
|
||||
)
|
||||
|
||||
system_msg = full_messages[0]
|
||||
assert isinstance(system_msg["content"], str)
|
||||
assert "2026-04-23" in system_msg["content"]
|
||||
assert "cache_control" not in system_msg
|
||||
|
||||
|
||||
class TestExtractCacheTokens:
|
||||
"""`_extract_cache_tokens` reads cache_read + cache_creation from the
|
||||
LiteLLM-normalized usage object. Both fields are subsets of
|
||||
``prompt_tokens`` — the helper surfaces them for display, the call sites
|
||||
are responsible for never adding them to a total."""
|
||||
|
||||
def test_none_usage_returns_zero(self):
|
||||
assert _extract_cache_tokens(None) == (0, 0)
|
||||
|
||||
def test_openai_shape(self):
|
||||
"""Pure OpenAI responses expose cached reads via
|
||||
``prompt_tokens_details.cached_tokens`` and have no cache write
|
||||
field at all (OpenAI's automatic caching is read-only from the
|
||||
client's perspective)."""
|
||||
usage = MagicMock(spec=["prompt_tokens_details", "cache_creation_input_tokens"])
|
||||
usage.prompt_tokens_details = MagicMock(
|
||||
spec=["cached_tokens"],
|
||||
cached_tokens=120,
|
||||
)
|
||||
usage.cache_creation_input_tokens = 0
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
assert cache_read == 120
|
||||
assert cache_creation == 0
|
||||
|
||||
def test_openrouter_cache_write_tokens_shape(self):
|
||||
"""OpenRouter normalizes cache writes into
|
||||
``prompt_tokens_details.cache_write_tokens`` (verified empirically
|
||||
against openrouter/anthropic and openrouter/z-ai responses). The
|
||||
legacy ``usage.cache_creation_input_tokens`` field is NOT set on
|
||||
OpenRouter responses, so this is the path that matters in practice."""
|
||||
usage = MagicMock()
|
||||
usage.prompt_tokens_details = MagicMock(
|
||||
cached_tokens=80,
|
||||
cache_write_tokens=50,
|
||||
)
|
||||
# Explicitly set the Anthropic-native field to 0 to prove we don't
|
||||
# depend on it for OpenRouter responses.
|
||||
usage.cache_creation_input_tokens = 0
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
assert cache_read == 80
|
||||
assert cache_creation == 50
|
||||
|
||||
def test_anthropic_native_cache_creation_field_still_works(self):
|
||||
"""Direct Anthropic API responses (not via OpenRouter) put cache
|
||||
writes on the top-level ``cache_creation_input_tokens`` field. Keep
|
||||
the fallback so non-OpenRouter Anthropic continues to work."""
|
||||
usage = MagicMock(spec=["prompt_tokens_details", "cache_creation_input_tokens"])
|
||||
usage.prompt_tokens_details = MagicMock(
|
||||
spec=["cached_tokens"],
|
||||
cached_tokens=80,
|
||||
)
|
||||
usage.cache_creation_input_tokens = 50
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
assert cache_read == 80
|
||||
assert cache_creation == 50
|
||||
|
||||
def test_raw_anthropic_shape_falls_back(self):
|
||||
"""Raw Anthropic usage (no prompt_tokens_details) — fall back to
|
||||
cache_read_input_tokens."""
|
||||
usage = MagicMock(spec=["cache_read_input_tokens", "cache_creation_input_tokens"])
|
||||
usage.cache_read_input_tokens = 200
|
||||
usage.cache_creation_input_tokens = 75
|
||||
# Force prompt_tokens_details to be missing on the spec'd mock.
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
assert cache_read == 200
|
||||
assert cache_creation == 75
|
||||
|
||||
def test_no_cache_fields_returns_zero(self):
|
||||
"""A provider that doesn't report cache tokens at all (e.g. Gemini)
|
||||
returns (0, 0) — never raises."""
|
||||
usage = MagicMock(spec=["prompt_tokens", "completion_tokens"])
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
assert cache_read == 0
|
||||
assert cache_creation == 0
|
||||
|
||||
|
||||
class TestStreamingChunksFallbackPreservesCacheFields:
|
||||
"""Regression: when LiteLLM strips usage from yielded streaming chunks,
|
||||
we fall back to ``response.chunks`` to recover token totals. LiteLLM's
|
||||
own ``calculate_total_usage()`` aggregates ``prompt_tokens`` /
|
||||
``completion_tokens`` correctly but DROPS ``prompt_tokens_details`` —
|
||||
which is where OpenRouter places ``cached_tokens`` and
|
||||
``cache_write_tokens``. The fallback path must walk the raw chunks to
|
||||
recover those fields, otherwise streaming OpenRouter calls always
|
||||
report zero cache tokens. (Verified empirically against
|
||||
openrouter/anthropic/* and openrouter/z-ai/*.)"""
|
||||
|
||||
def test_chunks_with_cache_fields_recovered(self):
|
||||
"""Simulate the chunks-fallback hot path: build raw chunks where the
|
||||
last one carries cache_write_tokens, run the same recovery loop the
|
||||
streaming code uses, and assert we surface the cache fields."""
|
||||
# Three chunks: text deltas, then a final chunk with usage.
|
||||
empty_usage_chunk = MagicMock()
|
||||
empty_usage_chunk.usage = None
|
||||
last_chunk = MagicMock()
|
||||
last_chunk.usage = MagicMock()
|
||||
last_chunk.usage.prompt_tokens_details = MagicMock(
|
||||
cached_tokens=0,
|
||||
cache_write_tokens=5601,
|
||||
)
|
||||
last_chunk.usage.cache_creation_input_tokens = 0
|
||||
chunks = [empty_usage_chunk, empty_usage_chunk, last_chunk]
|
||||
|
||||
# Mirror the production loop in litellm.py's chunks-fallback.
|
||||
cached, creation = 0, 0
|
||||
for raw in reversed(chunks):
|
||||
usage = getattr(raw, "usage", None)
|
||||
if usage is None:
|
||||
continue
|
||||
cr, cc = _extract_cache_tokens(usage)
|
||||
if cr or cc:
|
||||
cached, creation = cr, cc
|
||||
break
|
||||
|
||||
assert cached == 0
|
||||
assert creation == 5601, (
|
||||
"chunks-fallback must recover cache_write_tokens from the raw "
|
||||
"chunk, not from calculate_total_usage which strips details"
|
||||
)
|
||||
|
||||
def test_chunks_with_cache_read_recovered(self):
|
||||
"""Same path, but for a cache HIT (cached_tokens populated)."""
|
||||
last_chunk = MagicMock()
|
||||
last_chunk.usage = MagicMock()
|
||||
last_chunk.usage.prompt_tokens_details = MagicMock(
|
||||
cached_tokens=5601,
|
||||
cache_write_tokens=0,
|
||||
)
|
||||
last_chunk.usage.cache_creation_input_tokens = 0
|
||||
|
||||
cached, creation = 0, 0
|
||||
for raw in reversed([last_chunk]):
|
||||
usage = getattr(raw, "usage", None)
|
||||
if usage is None:
|
||||
continue
|
||||
cr, cc = _extract_cache_tokens(usage)
|
||||
if cr or cc:
|
||||
cached, creation = cr, cc
|
||||
break
|
||||
|
||||
assert cached == 5601
|
||||
assert creation == 0
|
||||
|
||||
|
||||
class TestExtractCost:
|
||||
"""`_extract_cost` pulls USD cost from three sources in order:
|
||||
usage.cost (OpenRouter native / include_cost_in_streaming_usage) →
|
||||
response._hidden_params['response_cost'] (LiteLLM logging) →
|
||||
litellm.completion_cost() (pricing-table fallback)."""
|
||||
|
||||
def test_none_response_returns_zero(self):
|
||||
assert _extract_cost(None, "gpt-4o-mini") == 0.0
|
||||
|
||||
def test_openrouter_usage_cost_is_preferred(self):
|
||||
"""OpenRouter returns authoritative per-call cost on usage.cost when
|
||||
the caller opts in (usage.include=true). That beats LiteLLM's
|
||||
pricing-table estimate because it reflects promo pricing and BYOK markup."""
|
||||
response = MagicMock()
|
||||
response.usage = MagicMock(cost=0.00123)
|
||||
response._hidden_params = {"response_cost": 99.99} # should be ignored
|
||||
assert _extract_cost(response, "openrouter/anthropic/claude-opus-4.5") == 0.00123
|
||||
|
||||
def test_hidden_params_response_cost_used_when_no_usage_cost(self):
|
||||
"""LiteLLM's logging layer attaches response_cost after most
|
||||
completions — this is how OpenAI/Anthropic responses get costed
|
||||
without going back to the pricing table."""
|
||||
response = MagicMock()
|
||||
response.usage = MagicMock(spec=[]) # no .cost attribute
|
||||
response._hidden_params = {"response_cost": 0.0042}
|
||||
assert _extract_cost(response, "gpt-4o-mini") == 0.0042
|
||||
|
||||
def test_falls_back_to_completion_cost_when_nothing_pre_populated(self):
|
||||
"""For providers where LiteLLM didn't pre-populate cost, call
|
||||
litellm.completion_cost() against the pricing table. Mocked here
|
||||
because we don't want tests depending on the exact price of
|
||||
claude-sonnet-4.5 in LiteLLM's model map."""
|
||||
response = MagicMock()
|
||||
response.usage = MagicMock(spec=[])
|
||||
response._hidden_params = {}
|
||||
with patch("litellm.completion_cost", return_value=0.00789):
|
||||
assert _extract_cost(response, "anthropic/claude-sonnet-4.5") == 0.00789
|
||||
|
||||
def test_completion_cost_exception_returns_zero(self):
|
||||
"""Unpriced models (e.g. new OpenRouter routes not yet in LiteLLM's
|
||||
catalog) must not crash the hot path."""
|
||||
response = MagicMock()
|
||||
response.usage = MagicMock(spec=[])
|
||||
response._hidden_params = {}
|
||||
with patch("litellm.completion_cost", side_effect=Exception("no pricing")):
|
||||
assert _extract_cost(response, "openrouter/mystery/model") == 0.0
|
||||
|
||||
def test_zero_cost_falls_through_to_next_source(self):
|
||||
"""usage.cost == 0 should NOT short-circuit; fall through to
|
||||
_hidden_params / completion_cost so we don't cement a false zero."""
|
||||
response = MagicMock()
|
||||
response.usage = MagicMock(cost=0.0)
|
||||
response._hidden_params = {"response_cost": 0.0055}
|
||||
assert _extract_cost(response, "gpt-4o-mini") == 0.0055
|
||||
|
||||
|
||||
class TestCostFromTokens:
|
||||
"""`_cost_from_tokens` is the streaming-path cost helper: stream wrappers
|
||||
don't expose the full ModelResponse shape that completion_cost() expects,
|
||||
so we go through cost_per_token() with the already-extracted totals."""
|
||||
|
||||
def test_zero_tokens_returns_zero_without_calling_litellm(self):
|
||||
with patch("litellm.cost_per_token") as mock:
|
||||
assert _cost_from_tokens("claude-opus-4.5", 0, 0) == 0.0
|
||||
mock.assert_not_called()
|
||||
|
||||
def test_empty_model_returns_zero(self):
|
||||
assert _cost_from_tokens("", 1000, 500) == 0.0
|
||||
|
||||
def test_computes_from_tokens(self):
|
||||
with patch("litellm.cost_per_token", return_value=(0.001, 0.002)) as mock:
|
||||
cost = _cost_from_tokens(
|
||||
"anthropic/claude-opus-4.5",
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cached_tokens=200,
|
||||
cache_creation_tokens=100,
|
||||
)
|
||||
assert cost == pytest.approx(0.003)
|
||||
# Verify the cache-aware kwargs are threaded through — Anthropic
|
||||
# needs these to apply the 1.25x write / 0.1x read multipliers.
|
||||
call_kwargs = mock.call_args.kwargs
|
||||
assert call_kwargs["prompt_tokens"] == 1000
|
||||
assert call_kwargs["completion_tokens"] == 500
|
||||
assert call_kwargs["cache_read_input_tokens"] == 200
|
||||
assert call_kwargs["cache_creation_input_tokens"] == 100
|
||||
|
||||
def test_exception_returns_zero(self):
|
||||
with patch("litellm.cost_per_token", side_effect=Exception("unpriced")):
|
||||
assert _cost_from_tokens("mystery/model", 1000, 500) == 0.0
|
||||
|
||||
def test_negative_or_none_components_coerce_to_zero(self):
|
||||
"""LiteLLM returns (None, None) for unknown models in some versions;
|
||||
treat as 0 rather than crashing on None+None."""
|
||||
with patch("litellm.cost_per_token", return_value=(None, None)):
|
||||
assert _cost_from_tokens("some/model", 1, 1) == 0.0
|
||||
|
||||
|
||||
class TestLLMResponseAndFinishEventHaveCostUsd:
|
||||
"""Regression: both LLMResponse and FinishEvent must carry cost_usd so
|
||||
the agent loop → event bus → frontend pipeline doesn't lose cost."""
|
||||
|
||||
def test_llm_response_defaults_cost_to_zero(self):
|
||||
from framework.llm.provider import LLMResponse
|
||||
|
||||
r = LLMResponse(content="", model="m")
|
||||
assert r.cost_usd == 0.0
|
||||
|
||||
def test_finish_event_defaults_cost_to_zero(self):
|
||||
from framework.llm.stream_events import FinishEvent
|
||||
|
||||
e = FinishEvent()
|
||||
assert e.cost_usd == 0.0
|
||||
|
||||
def test_finish_event_accepts_cost(self):
|
||||
from framework.llm.stream_events import FinishEvent
|
||||
|
||||
e = FinishEvent(cost_usd=0.0123)
|
||||
assert e.cost_usd == 0.0123
|
||||
|
||||
@@ -24,12 +24,12 @@ def test_default_models_exist_in_each_provider_catalogue():
|
||||
|
||||
|
||||
def test_find_model_returns_curated_token_limits():
|
||||
model = model_catalog.find_model("openai", "gpt-5.4")
|
||||
model = model_catalog.find_model("openai", "gpt-5.5")
|
||||
|
||||
assert model is not None
|
||||
assert model["label"] == "GPT-5.4 - Best intelligence"
|
||||
assert model["label"] == "GPT-5.5 - Frontier coding + reasoning"
|
||||
assert model["max_tokens"] == 128000
|
||||
assert model["max_context_tokens"] == 960000
|
||||
assert model["max_context_tokens"] == 1050000
|
||||
|
||||
|
||||
def test_anthropic_curated_limits_track_documented_caps_with_safe_input_budget():
|
||||
@@ -125,15 +125,22 @@ def test_deepseek_catalog_tracks_current_api_models():
|
||||
deepseek_default = model_catalog.get_default_models()["deepseek"]
|
||||
deepseek_models = model_catalog.get_models_catalogue()["deepseek"]
|
||||
|
||||
assert deepseek_default == "deepseek-chat"
|
||||
assert deepseek_default == "deepseek-v4-pro"
|
||||
assert [model["id"] for model in deepseek_models] == [
|
||||
"deepseek-chat",
|
||||
"deepseek-v4-pro",
|
||||
"deepseek-v4-flash",
|
||||
"deepseek-reasoner",
|
||||
]
|
||||
assert deepseek_models[0]["max_tokens"] == 8192
|
||||
assert deepseek_models[0]["max_context_tokens"] == 128000
|
||||
assert deepseek_models[1]["max_tokens"] == 64000
|
||||
assert deepseek_models[1]["max_context_tokens"] == 128000
|
||||
# V4 family — 1M context, 384k max output, mirrors api-docs.deepseek.com pricing.
|
||||
assert deepseek_models[0]["max_tokens"] == 384000
|
||||
assert deepseek_models[0]["max_context_tokens"] == 1000000
|
||||
assert deepseek_models[0]["pricing_usd_per_mtok"]["input"] == 1.74
|
||||
assert deepseek_models[0]["pricing_usd_per_mtok"]["output"] == 3.48
|
||||
assert deepseek_models[1]["pricing_usd_per_mtok"]["input"] == 0.14
|
||||
assert deepseek_models[1]["pricing_usd_per_mtok"]["output"] == 0.28
|
||||
# Legacy reasoner kept for back-compat while users migrate.
|
||||
assert deepseek_models[2]["max_tokens"] == 64000
|
||||
assert deepseek_models[2]["max_context_tokens"] == 128000
|
||||
|
||||
|
||||
def test_openrouter_catalog_tracks_current_frontier_set():
|
||||
|
||||
@@ -298,6 +298,35 @@ class TestNodeConversation:
|
||||
assert conv.compaction_warning() is True
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction_uses_hybrid_buffer(self):
|
||||
"""Hybrid: effective buffer is fixed_tokens + ratio * max_context.
|
||||
|
||||
With max=1000, fixed=200, ratio=0.1 → effective_buffer=300, so
|
||||
the trigger threshold is 700.
|
||||
"""
|
||||
conv = NodeConversation(
|
||||
max_context_tokens=1000,
|
||||
compaction_buffer_tokens=200,
|
||||
compaction_buffer_ratio=0.1,
|
||||
)
|
||||
conv.update_token_count(650)
|
||||
assert conv.needs_compaction() is False
|
||||
conv.update_token_count(700)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_needs_compaction_ratio_only(self):
|
||||
"""Ratio component alone (without a fixed floor) still works."""
|
||||
conv = NodeConversation(
|
||||
max_context_tokens=1000,
|
||||
compaction_buffer_ratio=0.25,
|
||||
)
|
||||
conv.update_token_count(740)
|
||||
assert conv.needs_compaction() is False
|
||||
conv.update_token_count(760)
|
||||
assert conv.needs_compaction() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_threshold_rule_still_works_without_buffer(self):
|
||||
"""Without compaction_buffer_tokens, the old multiplicative rule
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
"""HTTP integration tests for the skills routes.
|
||||
|
||||
Covers the per-queen, per-colony, and aggregated-library surfaces plus
|
||||
the multipart upload handler. Uses aiohttp's TestClient directly (no
|
||||
pytest-aiohttp plugin), which is why each test sets up its own client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.server.routes_skills import register_routes
|
||||
from framework.skills.overrides import (
|
||||
OverrideEntry,
|
||||
Provenance,
|
||||
SkillOverrideStore,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class _StubSessionManager:
|
||||
"""Tiny stand-in that satisfies the iter_* contracts used by routes.
|
||||
|
||||
The routes_skills handlers call ``manager.iter_queen_sessions`` and
|
||||
``manager.iter_colony_runtimes`` to find live managers to reload.
|
||||
In-process tests don't spin up runtimes, so these iterators yield
|
||||
nothing — the routes fall back to the admin manager built from disk.
|
||||
"""
|
||||
|
||||
def iter_queen_sessions(self, queen_id: str):
|
||||
return iter([])
|
||||
|
||||
def iter_colony_runtimes(self, *, queen_id=None, colony_name=None):
|
||||
return iter([])
|
||||
|
||||
|
||||
def _build_app() -> web.Application:
|
||||
application = web.Application()
|
||||
application["manager"] = _StubSessionManager()
|
||||
register_routes(application)
|
||||
return application
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client() -> AsyncIterator[TestClient]:
|
||||
app = _build_app()
|
||||
server = TestServer(app)
|
||||
async with TestClient(server) as tc:
|
||||
yield tc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _seed_queen(tmp_path: Path):
|
||||
"""Write a queen profile so _queen_scope recognises the id."""
|
||||
queen_home = Path.home() / ".hive" / "agents" / "queens" / "ops"
|
||||
queen_home.mkdir(parents=True, exist_ok=True)
|
||||
(queen_home / "profile.yaml").write_text("name: Ops\ntitle: Ops queen\n", encoding="utf-8")
|
||||
return queen_home
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _seed_colony(tmp_path: Path):
|
||||
colony_home = Path.home() / ".hive" / "colonies" / "research_one"
|
||||
colony_home.mkdir(parents=True, exist_ok=True)
|
||||
return colony_home
|
||||
|
||||
|
||||
async def test_get_queen_skills_returns_empty_for_fresh_queen(client: TestClient, _seed_queen) -> None:
|
||||
resp = await client.get("/api/queen/ops/skills")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["queen_id"] == "ops"
|
||||
assert data["all_defaults_disabled"] is False
|
||||
# Fresh install → framework default skills show up via discovery.
|
||||
assert isinstance(data["skills"], list)
|
||||
|
||||
|
||||
async def test_create_queen_skill_writes_file_and_override(client: TestClient, _seed_queen) -> None:
|
||||
payload = {
|
||||
"name": "ops-runbook",
|
||||
"description": "Runbook for ops",
|
||||
"body": "## Steps\n1. Check\n",
|
||||
"enabled": True,
|
||||
}
|
||||
resp = await client.post("/api/queen/ops/skills", json=payload)
|
||||
assert resp.status == 201
|
||||
data = await resp.json()
|
||||
assert data["name"] == "ops-runbook"
|
||||
# Verify files were written to the queen skill dir.
|
||||
skill_md = _seed_queen / "skills" / "ops-runbook" / "SKILL.md"
|
||||
assert skill_md.exists()
|
||||
# Verify override was registered with USER_UI_CREATED provenance.
|
||||
store = SkillOverrideStore.load(_seed_queen / "skills_overrides.json")
|
||||
entry = store.get("ops-runbook")
|
||||
assert entry is not None
|
||||
assert entry.provenance == Provenance.USER_UI_CREATED
|
||||
assert entry.enabled is True
|
||||
|
||||
|
||||
async def test_patch_queen_skill_toggles_enabled(client: TestClient, _seed_queen) -> None:
|
||||
await client.post(
|
||||
"/api/queen/ops/skills",
|
||||
json={"name": "ops-a", "description": "a", "body": "body"},
|
||||
)
|
||||
resp = await client.patch(
|
||||
"/api/queen/ops/skills/ops-a",
|
||||
json={"enabled": False},
|
||||
)
|
||||
assert resp.status == 200
|
||||
store = SkillOverrideStore.load(_seed_queen / "skills_overrides.json")
|
||||
assert store.get("ops-a").enabled is False
|
||||
|
||||
|
||||
async def test_delete_queen_skill_removes_files(client: TestClient, _seed_queen) -> None:
|
||||
await client.post(
|
||||
"/api/queen/ops/skills",
|
||||
json={"name": "tmp-skill", "description": "d", "body": "body"},
|
||||
)
|
||||
skill_dir = _seed_queen / "skills" / "tmp-skill"
|
||||
assert skill_dir.exists()
|
||||
|
||||
resp = await client.delete("/api/queen/ops/skills/tmp-skill")
|
||||
assert resp.status == 200
|
||||
assert not skill_dir.exists()
|
||||
store = SkillOverrideStore.load(_seed_queen / "skills_overrides.json")
|
||||
assert "tmp-skill" in store.deleted_ui_skills
|
||||
|
||||
|
||||
async def test_delete_framework_skill_is_refused(client: TestClient, _seed_queen) -> None:
|
||||
# Pre-seed an override entry with framework provenance — simulates the
|
||||
# user toggling a framework default so the override exists on disk.
|
||||
store = SkillOverrideStore.load(_seed_queen / "skills_overrides.json")
|
||||
store.upsert(
|
||||
"hive.note-taking",
|
||||
OverrideEntry(enabled=False, provenance=Provenance.FRAMEWORK),
|
||||
)
|
||||
store.save()
|
||||
|
||||
resp = await client.delete("/api/queen/ops/skills/hive.note-taking")
|
||||
assert resp.status == 403
|
||||
|
||||
|
||||
async def test_upload_markdown_places_in_user_library(client: TestClient) -> None:
|
||||
skill_md = "---\nname: from-upload\ndescription: Uploaded skill\n---\n\n## Body\nHi.\n"
|
||||
form = {
|
||||
"file": skill_md.encode("utf-8"),
|
||||
"scope": "user",
|
||||
"enabled": "true",
|
||||
}
|
||||
# Use multipart writer pattern: aiohttp test client auto-serializes dicts.
|
||||
data = _as_form(form, filename="SKILL.md")
|
||||
resp = await client.post("/api/skills/upload", data=data)
|
||||
assert resp.status == 201
|
||||
body = await resp.json()
|
||||
assert body["name"] == "from-upload"
|
||||
assert (Path.home() / ".hive" / "skills" / "from-upload" / "SKILL.md").exists()
|
||||
|
||||
|
||||
async def test_upload_zip_bundle_places_in_queen_scope(client: TestClient, _seed_queen) -> None:
|
||||
# Build a zip in memory with SKILL.md + a supporting file.
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as z:
|
||||
z.writestr(
|
||||
"SKILL.md",
|
||||
"---\nname: zipped-skill\ndescription: From zip\n---\n\nbody\n",
|
||||
)
|
||||
z.writestr("scripts/helper.py", "print('hi')\n")
|
||||
payload = buf.getvalue()
|
||||
form = {
|
||||
"file": payload,
|
||||
"scope": "queen",
|
||||
"target_id": "ops",
|
||||
"enabled": "true",
|
||||
}
|
||||
data = _as_form(form, filename="bundle.zip")
|
||||
resp = await client.post("/api/skills/upload", data=data)
|
||||
assert resp.status == 201
|
||||
skill_dir = _seed_queen / "skills" / "zipped-skill"
|
||||
assert (skill_dir / "SKILL.md").exists()
|
||||
assert (skill_dir / "scripts" / "helper.py").exists()
|
||||
|
||||
|
||||
async def test_patch_does_not_mislabel_legacy_colony_skill_as_framework(client: TestClient, _seed_colony) -> None:
|
||||
"""Regression: toggling a legacy colony skill (no ledger entry yet)
|
||||
must not stamp provenance=FRAMEWORK on the new entry. Before the fix,
|
||||
the first PATCH wrote FRAMEWORK and the next GET displayed 'Framework'
|
||||
instead of the queen-authored label.
|
||||
"""
|
||||
skill_dir = _seed_colony / ".hive" / "skills" / "legacy-queen-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: legacy-queen-skill\ndescription: From create_colony\n---\n\nbody\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
resp = await client.patch(
|
||||
"/api/colonies/research_one/skills/legacy-queen-skill",
|
||||
json={"enabled": False},
|
||||
)
|
||||
assert resp.status == 200
|
||||
|
||||
list_resp = await client.get("/api/colonies/research_one/skills")
|
||||
rows = {r["name"]: r for r in (await list_resp.json())["skills"]}
|
||||
assert rows["legacy-queen-skill"]["provenance"] == "queen_created"
|
||||
assert rows["legacy-queen-skill"]["enabled"] is False
|
||||
|
||||
|
||||
async def test_colony_skill_is_editable_even_without_override_entry(client: TestClient, _seed_colony) -> None:
|
||||
"""Regression: a SKILL.md dropped into a colony's .hive/skills dir
|
||||
(e.g. from a pre-override-store colony) must still be marked editable
|
||||
when listed via /api/colonies/{name}/skills. The admin manager used
|
||||
to set project_root=colony_home, which retagged the skill as
|
||||
source_scope='project' and fell back to PROJECT_DROPPED provenance —
|
||||
flipping editable to False.
|
||||
"""
|
||||
# Write a bare SKILL.md directly; no override ledger entry.
|
||||
skill_dir = _seed_colony / ".hive" / "skills" / "legacy-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: legacy-skill\ndescription: A legacy\n---\n\nbody\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
resp = await client.get("/api/colonies/research_one/skills")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
rows = {r["name"]: r for r in data["skills"]}
|
||||
assert "legacy-skill" in rows
|
||||
assert rows["legacy-skill"]["editable"] is True
|
||||
assert rows["legacy-skill"]["source_scope"] == "colony_ui"
|
||||
# Legacy colony skills (no override ledger entry) were authored by
|
||||
# create_colony() before the ledger existed — the fallback provenance
|
||||
# must reflect that, not be misreported as user-UI-created.
|
||||
assert rows["legacy-skill"]["provenance"] == "queen_created"
|
||||
|
||||
|
||||
async def test_list_scopes_enumerates_queens_and_colonies(client: TestClient, _seed_queen, _seed_colony) -> None:
|
||||
resp = await client.get("/api/skills/scopes")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert any(q["id"] == "ops" for q in data["queens"])
|
||||
assert any(c["name"] == "research_one" for c in data["colonies"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _as_form(fields: dict, *, filename: str):
|
||||
"""Build aiohttp FormData; bytes entries are attached as file parts."""
|
||||
from aiohttp import FormData
|
||||
|
||||
fd = FormData()
|
||||
for key, value in fields.items():
|
||||
if isinstance(value, bytes):
|
||||
fd.add_field(key, value, filename=filename)
|
||||
else:
|
||||
fd.add_field(key, value)
|
||||
return fd
|
||||
@@ -53,6 +53,7 @@ class _ByTaskMockLLM(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
**kwargs,
|
||||
) -> AsyncIterator:
|
||||
first_user = ""
|
||||
for m in messages:
|
||||
@@ -329,6 +330,7 @@ class _SlowLLM(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
**kwargs,
|
||||
) -> AsyncIterator:
|
||||
self._turn_count += 1
|
||||
# On the second call (after the watcher's inject), check whether the
|
||||
|
||||
@@ -0,0 +1,274 @@
|
||||
"""Tests for the per-scope skill override store and its interaction with SkillsManager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.skills.authoring import build_draft, write_skill
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.discovery import ExtraScope
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
from framework.skills.overrides import (
|
||||
OverrideEntry,
|
||||
Provenance,
|
||||
SkillOverrideStore,
|
||||
)
|
||||
|
||||
|
||||
def _write_skill_file(base: Path, name: str, description: str = "desc") -> Path:
|
||||
skill_dir = base / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: {description}\n---\n\nbody\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
|
||||
class TestSkillOverrideStore:
|
||||
def test_load_missing_returns_empty(self, tmp_path: Path) -> None:
|
||||
store = SkillOverrideStore.load(tmp_path / "skills_overrides.json", scope_label="queen:x")
|
||||
assert store.overrides == {}
|
||||
assert store.all_defaults_disabled is False
|
||||
|
||||
def test_upsert_and_save_roundtrip(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "skills_overrides.json"
|
||||
store = SkillOverrideStore.load(path, scope_label="queen:x")
|
||||
store.upsert(
|
||||
"foo",
|
||||
OverrideEntry(
|
||||
enabled=False,
|
||||
provenance=Provenance.FRAMEWORK,
|
||||
created_at=datetime(2026, 4, 21, tzinfo=UTC),
|
||||
created_by="user",
|
||||
),
|
||||
)
|
||||
store.save()
|
||||
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
assert raw["version"] == 1
|
||||
assert raw["overrides"]["foo"]["enabled"] is False
|
||||
assert raw["overrides"]["foo"]["provenance"] == "framework"
|
||||
|
||||
# Re-load preserves values
|
||||
again = SkillOverrideStore.load(path, scope_label="queen:x")
|
||||
assert again.get("foo") is not None
|
||||
assert again.get("foo").enabled is False
|
||||
|
||||
def test_tombstone_survives_reload(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "skills_overrides.json"
|
||||
store = SkillOverrideStore.load(path, scope_label="queen:x")
|
||||
store.upsert("foo", OverrideEntry(enabled=True, provenance=Provenance.USER_UI_CREATED))
|
||||
store.remove("foo", tombstone=True)
|
||||
store.save()
|
||||
again = SkillOverrideStore.load(path, scope_label="queen:x")
|
||||
assert "foo" in again.deleted_ui_skills
|
||||
assert again.get("foo") is None
|
||||
|
||||
def test_corrupt_file_loads_empty(self, tmp_path: Path) -> None:
|
||||
path = tmp_path / "skills_overrides.json"
|
||||
path.write_text("{not valid json", encoding="utf-8")
|
||||
store = SkillOverrideStore.load(path, scope_label="queen:x")
|
||||
assert store.overrides == {}
|
||||
|
||||
|
||||
class TestAuthoring:
|
||||
def test_write_and_remove(self, tmp_path: Path) -> None:
|
||||
draft, err = build_draft(
|
||||
skill_name="demo",
|
||||
skill_description="A demo skill",
|
||||
skill_body="## Steps\n1. Do it.\n",
|
||||
skill_files=[{"path": "notes.md", "content": "notes"}],
|
||||
)
|
||||
assert err is None
|
||||
assert draft is not None
|
||||
installed, werr, replaced = write_skill(draft, target_root=tmp_path, replace_existing=True)
|
||||
assert werr is None
|
||||
assert installed is not None
|
||||
assert (installed / "SKILL.md").exists()
|
||||
assert (installed / "notes.md").read_text() == "notes"
|
||||
assert replaced is False
|
||||
|
||||
def test_reject_absolute_path(self, tmp_path: Path) -> None:
|
||||
_, err = build_draft(
|
||||
skill_name="demo",
|
||||
skill_description="desc",
|
||||
skill_body="body",
|
||||
skill_files=[{"path": "/etc/passwd", "content": "oops"}],
|
||||
)
|
||||
assert err is not None
|
||||
assert "relative" in err
|
||||
|
||||
def test_reject_traversal(self, tmp_path: Path) -> None:
|
||||
_, err = build_draft(
|
||||
skill_name="demo",
|
||||
skill_description="desc",
|
||||
skill_body="body",
|
||||
skill_files=[{"path": "../escape.sh", "content": "oops"}],
|
||||
)
|
||||
assert err is not None
|
||||
|
||||
def test_reject_invalid_name(self, tmp_path: Path) -> None:
|
||||
_, err = build_draft(
|
||||
skill_name="Demo_Skill",
|
||||
skill_description="desc",
|
||||
skill_body="body",
|
||||
)
|
||||
assert err is not None
|
||||
|
||||
|
||||
class TestSkillsManagerOverrides:
|
||||
def test_override_disables_framework_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Quarantine user-scope and skip framework-scope discovery by pointing HOME
|
||||
# at an empty tmp dir; supply only one "framework" skill manually via an
|
||||
# extra scope tagged as framework so the manager sees it.
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
fake_fw = tmp_path / "fake_framework"
|
||||
_write_skill_file(fake_fw, "hive.note-taking", "Fake default")
|
||||
|
||||
overrides_path = tmp_path / "queen_overrides.json"
|
||||
store = SkillOverrideStore.load(overrides_path, scope_label="queen:q")
|
||||
store.upsert(
|
||||
"hive.note-taking",
|
||||
OverrideEntry(enabled=False, provenance=Provenance.FRAMEWORK),
|
||||
)
|
||||
store.save()
|
||||
|
||||
mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id="q",
|
||||
queen_overrides_path=overrides_path,
|
||||
extra_scope_dirs=[ExtraScope(directory=fake_fw, label="framework", priority=0)],
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
mgr.load()
|
||||
|
||||
names_enabled = {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
assert "hive.note-taking" not in names_enabled
|
||||
# Enumeration (for UI rendering) still returns the hidden entry.
|
||||
assert any(s.name == "hive.note-taking" for s in mgr.enumerate_skills_with_source())
|
||||
|
||||
def test_colony_disable_overrides_queen_enable(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
|
||||
# One skill in a "queen_ui" extra scope.
|
||||
queen_skills = tmp_path / "queen_home" / "skills"
|
||||
_write_skill_file(queen_skills, "shared-skill")
|
||||
|
||||
queen_overrides = tmp_path / "queen_overrides.json"
|
||||
qstore = SkillOverrideStore.load(queen_overrides, scope_label="queen:q")
|
||||
qstore.upsert(
|
||||
"shared-skill",
|
||||
OverrideEntry(enabled=True, provenance=Provenance.USER_UI_CREATED),
|
||||
)
|
||||
qstore.save()
|
||||
|
||||
colony_overrides = tmp_path / "colony_overrides.json"
|
||||
cstore = SkillOverrideStore.load(colony_overrides, scope_label="colony:c")
|
||||
cstore.upsert(
|
||||
"shared-skill",
|
||||
OverrideEntry(enabled=False, provenance=Provenance.USER_UI_CREATED),
|
||||
)
|
||||
cstore.save()
|
||||
|
||||
mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id="q",
|
||||
queen_overrides_path=queen_overrides,
|
||||
colony_name="c",
|
||||
colony_overrides_path=colony_overrides,
|
||||
extra_scope_dirs=[ExtraScope(directory=queen_skills, label="queen_ui", priority=2)],
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
skills_config=SkillsConfig(),
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
mgr.load()
|
||||
enabled = {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
assert "shared-skill" not in enabled
|
||||
|
||||
def test_preset_scope_is_off_by_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Preset-scope skills (bundled capability packs) must stay out
|
||||
of the catalog until the user explicitly opts in."""
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
fake_presets = tmp_path / "fake_presets"
|
||||
_write_skill_file(fake_presets, "hive.x-automation", "X capability pack")
|
||||
_write_skill_file(fake_presets, "hive.browser-automation", "Browser pack")
|
||||
|
||||
mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
extra_scope_dirs=[ExtraScope(directory=fake_presets, label="preset", priority=1)],
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
mgr.load()
|
||||
enabled = {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
assert "hive.x-automation" not in enabled
|
||||
assert "hive.browser-automation" not in enabled
|
||||
# Enumeration still surfaces them so the UI can offer a toggle.
|
||||
enumerated = {s.name for s in mgr.enumerate_skills_with_source()}
|
||||
assert "hive.x-automation" in enumerated
|
||||
assert "hive.browser-automation" in enumerated
|
||||
|
||||
def test_preset_skill_enabled_via_explicit_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
fake_presets = tmp_path / "fake_presets"
|
||||
_write_skill_file(fake_presets, "hive.x-automation")
|
||||
|
||||
overrides_path = tmp_path / "queen_overrides.json"
|
||||
store = SkillOverrideStore.load(overrides_path, scope_label="queen:q")
|
||||
store.upsert(
|
||||
"hive.x-automation",
|
||||
OverrideEntry(enabled=True, provenance=Provenance.PRESET),
|
||||
)
|
||||
store.save()
|
||||
|
||||
mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id="q",
|
||||
queen_overrides_path=overrides_path,
|
||||
extra_scope_dirs=[ExtraScope(directory=fake_presets, label="preset", priority=1)],
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
mgr.load()
|
||||
enabled = {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
assert "hive.x-automation" in enabled
|
||||
|
||||
def test_reload_picks_up_store_change(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path / "home")
|
||||
fw = tmp_path / "fw"
|
||||
_write_skill_file(fw, "alpha")
|
||||
path = tmp_path / "queen.json"
|
||||
|
||||
mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id="q",
|
||||
queen_overrides_path=path,
|
||||
extra_scope_dirs=[ExtraScope(directory=fw, label="framework", priority=0)],
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
mgr.load()
|
||||
assert "alpha" in {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
|
||||
# Disable via override file + reload
|
||||
store = SkillOverrideStore.load(path, scope_label="queen:q")
|
||||
store.upsert("alpha", OverrideEntry(enabled=False, provenance=Provenance.FRAMEWORK))
|
||||
store.save()
|
||||
mgr.reload()
|
||||
assert "alpha" not in {s.name for s in mgr._catalog._skills.values()} # type: ignore[attr-defined]
|
||||
@@ -241,6 +241,8 @@ class TestEventSerialization:
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"cached_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"cost_usd": 0.0,
|
||||
"model": "gpt-4",
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# 🐝 Hive Agent v0.10.3
|
||||
|
||||
> Colonies grow up, and Queen DMs learn to listen.
|
||||
|
||||
v0.10.0 introduced colonies. v0.10.3 is the release where they stop feeling like a new concept bolted on and start feeling like the place you actually work. Alongside that, Queen DMs got the single biggest fix to single-agent chat since we shipped it: you can keep typing while the queen is thinking, and she'll hear you.
|
||||
|
||||
---
|
||||
|
||||
## The Colony, grown up
|
||||
|
||||
When you spawn a colony now, a few things happen that didn't before.
|
||||
|
||||
The queen who spawned it hands off cleanly — her session is compacted first, so the new colony doesn't inherit a bloated context and spend its first ten turns figuring out what it already knows. There's a short **incubating phase** between "spawn requested" and "colony live" where skills, storage, and scheduler tools get set up quietly in the background. By the time the colony is ready, it has its own scoped skill bundle and SQLite — no more cross-colony skill leakage, no more workers belonging to the wrong group.
|
||||
|
||||
The UI finally matches the model. The sidebar groups everything by colony with a DataGrid view, shows the active queen on a dedicated bar inside the colony, and lets you click a worker to open it as its own tab. Tables and workers are scoped to the colony you're looking at, which sounds obvious in hindsight and was a long-standing source of confusion. Queen identity — name, title, avatar — now travels with the queen into message bubbles, the profile pane, and the org chart, so it's consistent no matter where you see her.
|
||||
|
||||
If you were using colonies in v0.10.1 or v0.10.2, this release is the one where the experience stops fighting you.
|
||||
|
||||
## Queen DMs stop eating your keystrokes
|
||||
|
||||
The most common complaint about Queen DMs was simple: if the queen was mid-turn and you thought of something to add, your message either got lost or arrived at a weird moment. That's gone.
|
||||
|
||||
Messages you send while the queen is working now land in a **pending queue**, visible in the chat panel with a **Steer** or **Cancel** control. Steer folds your message into the turn in progress; Cancel drops it. When the queue auto-flushes, the "typing…" indicator no longer flickers, and the old bootstrap race that sometimes rendered your own message twice is fixed.
|
||||
|
||||
The queen also got a proper `ask_user` tool this release, so when she genuinely needs something from you, it shows up as a question — not as a regular chat message you have to parse as one. Tool calls in chat are grouped by session now, so a chatty worker doesn't drown out the queen's own thinking, and her avatar is on every bubble so you can tell who's talking at a glance.
|
||||
|
||||
## Smaller things worth knowing
|
||||
|
||||
- **Prometheus tool** for querying metrics from agents (#7047).
|
||||
- **Scheduler + triggers** got a UI pass, better reliability on trigger message delivery, and scheduler tools are now available during the incubating phase.
|
||||
- **VSCode extension** bumped to 1.0.1 with refreshed icons and a fix for frame-resize jank.
|
||||
- **Model catalog** updates for Xiaomi and OpenRouter selections.
|
||||
- **Runtime reliability:** cancelled executions now fully terminate before a session can restart (#7001), Codex `store=False` is honored correctly (#7089), and the UI handles a broken Aden API key gracefully instead of hanging.
|
||||
|
||||
## Upgrading from v0.10.2
|
||||
|
||||
No migration. Pull `main` at `v0.10.3` and restart Hive — your existing `~/.hive/` profiles, queens, colonies, and sessions keep working.
|
||||
|
||||
One thing to be aware of: worker and table tabs are now scoped per colony. If you expected them to be global, switch colonies in the sidebar to see each colony's own.
|
||||
@@ -0,0 +1,77 @@
|
||||
# 🐝 Hive Agent v0.10.4: Skill & Tool Library
|
||||
|
||||
> Skills and tools move from something the framework hands down into something you curate. Every queen and every colony now has a dedicated allowlist and a UI to manage it, and the system prompt gets smaller and cache-friendlier along the way.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
v0.10.4 turns skills and tools into first-class, user-editable surfaces.
|
||||
|
||||
Before this release, the skill and tool catalogs were effectively framework defaults: whatever a queen could reach, every queen could reach. Now each queen has her own tool allowlist, her own skills, and a pair of library pages where you can browse, enable, disable, upload, and author them. Colonies inherit their founding queen's configuration at creation time and then drift on their own — flip a tool off for a colony without touching the queen, or the other way around.
|
||||
|
||||
A quiet but important second theme: the system prompt is now **static** across a session. Date and time — the main source of per-turn churn — are now injected at turn time instead of baked into the prompt. That keeps the prompt prefix stable so provider-side prompt caching can do its job.
|
||||
|
||||
---
|
||||
|
||||
## 🆕 What's New
|
||||
|
||||
### Skill Library
|
||||
|
||||
- **Skill Library page** — browse every skill by scope (queen / colony / framework preset), view SKILL.md inline, toggle per-scope enablement, upload skills as `.md` or `.zip`, and author new skills from the UI.
|
||||
- **Per-scope overrides** — skill enablement is recorded in `~/.hive/agents/queens/{queen_id}/skills_overrides.json` and `~/.hive/colonies/{colony_name}/skills_overrides.json`; framework presets stay read-only, user-authored skills live under each scope's own skills directory.
|
||||
- **Skill provenance** — the API and UI now distinguish framework-preset skills, queen-authored skills, and colony-authored skills, so you can tell at a glance who owns a given skill.
|
||||
- **Skill authoring primitives** — a shared `framework.skills.authoring` module validates names, parses frontmatter, and materializes skill folders for the UI upload path, the `create_colony` tool's inline skills, and future runtime-learned skills.
|
||||
- **Preset rename** — built-in skills moved from `_default_skills/` to `_preset_skills/` to match the new "preset vs. user" split. Existing browser/linkedin/x automation skills carry over untouched.
|
||||
|
||||
### Tool Library
|
||||
|
||||
- **Tool Library page** with a shared `ToolsEditor` component used by the queen profile and colony settings panels.
|
||||
- **Per-queen tool allowlist** at `~/.hive/agents/queens/{queen_id}/tools.json`: `null` = allow all, `[]` = disable all, `["foo", "bar"]` = only these MCP tools pass the filter.
|
||||
- **Per-colony tool allowlist** at `~/.hive/colonies/{colony_name}/tools.json`, with the same schema, atomic writes, and independent lifecycle.
|
||||
- **Configurable defaults** — queens now carry a default tool/skill bundle that seeds each new colony, and the bundle itself is editable.
|
||||
- **Colony inheritance** — when a queen spawns a colony, the colony starts from the queen's tool and skill configuration. After spawn the two diverge freely.
|
||||
- **Colony sidecar** — `tools.json` lives next to `metadata.json` so identity/provenance (queen, created_at, workers) and tool gating evolve independently.
|
||||
|
||||
### MCP Server Management
|
||||
|
||||
- **MCP Servers panel** — dedicated settings UI for browsing, configuring, and enabling bundled and user MCP servers.
|
||||
- **`/api/mcp` routes** for listing built-in servers, inspecting state, and reporting errors with structured MCP error responses.
|
||||
- **Tool catalog wiring** — live queen sessions now surface their MCP tool catalog to the queen-tools and colony-tools endpoints, so the UI shows exactly what the running session can see.
|
||||
|
||||
### Prompt & Runtime
|
||||
|
||||
- **Static system prompt** — the agent loop, conversation, and provider adapters (LiteLLM, Antigravity, Codex, Mock) now build and freeze the system prompt once per session. Per-turn values that used to churn the prompt are gone.
|
||||
- **Date/time injected at turn time** — today's date and current time move from the system prompt into a turn-level injection path that updates cursor persistence and queen-lifecycle tooling.
|
||||
- **Queen orchestrator** — refreshed to pair with the static prompt model and the new tool/skill configuration layers.
|
||||
- **Session manager** — tightened session-creation input validation and reflection/skill edge handling; "create new session and switch branch" is now reliable.
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **No-cache middleware on `/api/*`** — every API response now carries `Cache-Control: no-store`. Without this, a one-off bad response (e.g. the SPA catch-all leaking `index.html` for an `/api/*` URL before a route was registered) could get pinned in the browser's disk cache and replayed forever, since our JSON handlers don't emit ETag/Last-Modified. Hard-refresh no longer required to recover.
|
||||
- **Tools & skills registration** — queens and colonies no longer end up with stale or duplicated entries after reloads.
|
||||
- **Session creation** — invalid inputs are rejected up front with clear errors instead of surfacing later as runtime failures.
|
||||
- **Skill / reflection edges** — tightened handling so reflection runs no longer see half-built skill state during scope reloads.
|
||||
- **Create new session + switch branch** flow works end-to-end without orphaning sessions.
|
||||
- **CI** — broken workflow repaired.
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Tests
|
||||
|
||||
- `test_routes_skills.py`, `test_skill_overrides.py`, `test_colony_tools.py`, `test_queen_tools.py`, `test_mcp_routes.py` — coverage added for every new route group and the override store.
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Upgrading from v0.10.3
|
||||
|
||||
No migration. Pull `main` at `v0.10.4` and restart Hive — existing `~/.hive/` profiles, queens, colonies, and sessions keep working.
|
||||
|
||||
Two things to know:
|
||||
|
||||
1. **Preset skills directory renamed** from `_default_skills/` to `_preset_skills/` inside the framework. If you had external scripts pointing at that path, update them. User-authored skills under `~/.hive/` are unaffected.
|
||||
2. **First open of a queen or colony writes a `tools.json` sidecar** the first time you edit its allowlist. If you don't touch the Tool Library, nothing is written and behavior matches v0.10.3 (allow all MCP tools).
|
||||
|
||||
Curate your queens. 🐝
|
||||
@@ -0,0 +1,85 @@
|
||||
# 🐝 Hive Agent v0.10.5: Cache-Aware Cost + New Frontier Models
|
||||
|
||||
> A patch release with two big practical wins: real prompt-cache hits across OpenRouter routes (and the cost numbers to prove it), plus first-class entries for GPT-5.5, DeepSeek V4 Pro/Flash, and GLM-5.1.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
### 💸 Huge cost cut from prompt caching
|
||||
|
||||
v0.10.4 made the system prompt static so providers could cache it. v0.10.5 actually collects on that work.
|
||||
|
||||
- **`cache_control` now propagates through OpenRouter** for the sub-providers whose upstream APIs honor it: `openrouter/anthropic/*`, `openrouter/google/gemini-*`, `openrouter/z-ai/glm*`, and `openrouter/minimax/*`. Direct Anthropic / Bedrock / Vertex routes already worked; OpenRouter routes were silently no-op'ing the cache marker before.
|
||||
- **Cache-token accounting is unified across providers.** A single `_extract_cache_tokens` helper now reads OpenAI-shape `prompt_tokens_details.cached_tokens`, Anthropic-raw `cache_read_input_tokens`, and OpenRouter's normalized `cache_write_tokens` / `cache_creation_input_tokens` — and surfaces both **cache-read** and **cache-creation** counts (subsets of the input total, never double-counted).
|
||||
- **Streaming cache tokens no longer get dropped.** LiteLLM's `calculate_total_usage` aggregates token totals but discards `prompt_tokens_details`; the stream path now reaches back into the most recent chunk to recover cached/cache-creation counts so the FinishEvent is accurate.
|
||||
- **Cost is reported in USD, not just tokens.** Every `LLMResponse` and `FinishEvent` now carries `cost_usd`. The extractor consults four sources in priority order: native `usage.cost` → LiteLLM `_hidden_params.response_cost` → `litellm.completion_cost` → curated catalog pricing — so models LiteLLM doesn't price (GLM, Kimi, MiniMax, DeepSeek V4) still get accurate numbers via the catalog fallback.
|
||||
- **Persistent cost tracking** — the cost number now flows through the event bus to the chat panel and queen DM, and is persisted across sessions instead of resetting on reload.
|
||||
|
||||
The combined effect: on a long Claude Sonnet / Opus session routed through OpenRouter, the static system prefix is now a cache hit on every turn after the first, and the panel shows you the dollar savings turn-by-turn.
|
||||
|
||||
### 🧠 New frontier models
|
||||
|
||||
- **GPT-5.5** is now the OpenAI default — frontier coding + reasoning, 128k output / 1.05M context, vision-capable.
|
||||
- **DeepSeek V4 Pro** and **DeepSeek V4 Flash** replace `deepseek-chat`. Both ship with **1M context**, **384k max output**, and full cache-read pricing (Pro: $1.74 / $3.48 / $0.145 per Mtok; Flash: $0.14 / $0.28 / $0.028). `deepseek-reasoner` is marked legacy.
|
||||
- **GLM-5.1** replaces `GLM-5` with cache-read pricing wired in.
|
||||
- **Catalog pricing schema** — every model can now declare `pricing_usd_per_mtok` with optional `cache_read` and `cache_creation` rates; validated on load.
|
||||
- **`supports_vision` flag** added to every model in the catalog and consulted by the new vision-fallback path so non-vision models can still receive image inputs via captioning.
|
||||
|
||||
---
|
||||
|
||||
## 🆕 What's New
|
||||
|
||||
### Cost & Cache
|
||||
|
||||
- **`cache_control` for OpenRouter sub-providers** — Anthropic, Gemini, GLM, MiniMax routes now mark the static system prefix as ephemeral cache. (@RichardTang-Aden)
|
||||
- **`_extract_cache_tokens` helper** — single reader for OpenAI / Anthropic / OpenRouter cache-token shapes; returns `(cache_read, cache_creation)`. (@RichardTang-Aden)
|
||||
- **Catalog pricing fallback** — `_cost_from_catalog_pricing` and `_cost_from_tokens` compute USD from `pricing_usd_per_mtok` when LiteLLM's catalog has no entry. (@RichardTang-Aden)
|
||||
- **Streaming usage recovery** — pull cache-token details from the last usage-bearing chunk after `calculate_total_usage` strips them. (@RichardTang-Aden)
|
||||
- **`cost_usd`, `cached_tokens`, `cache_creation_tokens`** added to `LLMResponse`, `FinishEvent`, and the stream-event bus. (@RichardTang-Aden)
|
||||
- **Persistent cost tracking** — costs survive session reload and surface in `ChatPanel` and `queen-dm`. (@RichardTang-Aden)
|
||||
|
||||
### Models & Catalog
|
||||
|
||||
- **GPT-5.5** as the new OpenAI default with 1.05M context + native pricing. (@RichardTang-Aden)
|
||||
- **DeepSeek V4 Pro / Flash** with 1M context, 384k output, and cache-read pricing. (@RichardTang-Aden)
|
||||
- **GLM-5.1** replaces GLM-5; cache-read pricing wired. (@RichardTang-Aden)
|
||||
- **`pricing_usd_per_mtok` schema** — validated `input` / `output` / `cache_read` / `cache_creation` per model. (@RichardTang-Aden)
|
||||
- **`supports_vision` flag** populated for every catalog entry; queried by the new vision-fallback path. (@RichardTang-Aden)
|
||||
- **`get_model_pricing` / `model_supports_vision`** helpers exposed from `framework.llm.model_catalog`. (@RichardTang-Aden)
|
||||
|
||||
### Vision & Agent Loop
|
||||
|
||||
- **Image vision fallback** — `framework.agent_loop.internals.vision_fallback` captions images for non-vision models so the same conversation works regardless of provider capability. (@TimothyZhang7)
|
||||
- **Hybrid compaction buffer** — context compaction now combines a fixed token reserve with a ratio-of-context buffer instead of one or the other. (@RichardTang-Aden)
|
||||
|
||||
### Frontend
|
||||
|
||||
- **Configuration UI redesign** — refreshed sidebar, prompt library, skills library, and tools editor. (@vincentjiang777)
|
||||
- **Cost + token usage in chat** — `ChatPanel` and `queen-dm` show running token consumption and USD cost per session. (@RichardTang-Aden)
|
||||
|
||||
### Tests
|
||||
|
||||
- `test_litellm_provider.py` (+448 lines) covering cache-token extraction, cost-extraction priority order, OpenRouter compat-mode cache wiring, and streaming usage recovery.
|
||||
- `test_model_catalog.py` extended for the new pricing schema and `supports_vision` flag.
|
||||
- `test_event_bus.py` / `test_stream_events.py` extended for the new cost + cache fields.
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **Vision caption** — fix incorrect caption attachment in the vision-fallback path. (@TimothyZhang7)
|
||||
- **Colony-fork test flake** — drain background fork tasks before asserting on colony-spawn artifacts. (@RichardTang-Aden)
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Upgrading from v0.10.4
|
||||
|
||||
No migration. Pull `main` at `v0.10.5` and restart Hive — existing `~/.hive/` profiles, queens, colonies, and sessions keep working.
|
||||
|
||||
Two things to know:
|
||||
|
||||
1. **Default DeepSeek model changed** from `deepseek-chat` to `deepseek-v4-pro`. If a queen is pinned to `deepseek-chat`, that id is gone from the catalog — pick `deepseek-v4-pro` or `deepseek-v4-flash`.
|
||||
2. **Default OpenAI model changed** from `gpt-5.4` to `gpt-5.5`. `gpt-5.4` stays in the catalog as the previous-flagship option.
|
||||
|
||||
Cache the prompts. 🐝
|
||||
+220
-7
@@ -1042,6 +1042,49 @@ print(json.dumps(config, indent=2))
|
||||
PY
|
||||
}
|
||||
|
||||
save_vision_fallback() {
|
||||
# Write the `vision_fallback` block to ~/.hive/configuration.json.
|
||||
# Args: provider_id, model, env_var (api_key_env_var), api_base (optional)
|
||||
# When provider_id is empty, REMOVE the block entirely (user opted out).
|
||||
local provider_id="$1"
|
||||
local model="$2"
|
||||
local env_var="$3"
|
||||
local api_base="${4:-}"
|
||||
|
||||
uv run python - "$provider_id" "$model" "$env_var" "$api_base" <<'PY'
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
provider_id, model, env_var, api_base = sys.argv[1:5]
|
||||
|
||||
cfg_path = Path.home() / ".hive" / "configuration.json"
|
||||
cfg_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(cfg_path, encoding="utf-8-sig") as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
config = {}
|
||||
|
||||
# Empty provider_id means the user opted out — drop the block.
|
||||
if not provider_id:
|
||||
config.pop("vision_fallback", None)
|
||||
else:
|
||||
block = {"provider": provider_id, "model": model}
|
||||
if env_var:
|
||||
block["api_key_env_var"] = env_var
|
||||
if api_base:
|
||||
block["api_base"] = api_base
|
||||
config["vision_fallback"] = block
|
||||
|
||||
tmp_path = cfg_path.with_name(cfg_path.name + ".tmp")
|
||||
with open(tmp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
tmp_path.replace(cfg_path)
|
||||
PY
|
||||
}
|
||||
|
||||
# Source shell rc file to pick up existing env vars (temporarily disable set -e)
|
||||
set +e
|
||||
if [ -f "$SHELL_RC_FILE" ]; then
|
||||
@@ -1309,9 +1352,11 @@ fi
|
||||
echo ""
|
||||
echo -e " ${CYAN}${BOLD}API key providers:${NC}"
|
||||
|
||||
# 8-13) API key providers — show (credential detected) if key already set
|
||||
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY OPENROUTER_API_KEY)
|
||||
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier" "OpenRouter - Bring any OpenRouter model")
|
||||
# 8-N) API key providers — show (credential detected) if key already set.
|
||||
# Order is reflected directly in the menu numbering; the case dispatcher
|
||||
# below resolves choice numbers via $((8 + index_in_arrays)).
|
||||
PROVIDER_MENU_ENVS=(ANTHROPIC_API_KEY OPENAI_API_KEY GEMINI_API_KEY GROQ_API_KEY CEREBRAS_API_KEY OPENROUTER_API_KEY DEEPSEEK_API_KEY)
|
||||
PROVIDER_MENU_NAMES=("Anthropic (Claude) - Recommended" "OpenAI (GPT)" "Google Gemini - Free tier available" "Groq - Fast, free tier" "Cerebras - Fast, free tier" "OpenRouter - Bring any OpenRouter model" "DeepSeek - V4 family")
|
||||
for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
|
||||
num=$((idx + 8))
|
||||
env_var="${PROVIDER_MENU_ENVS[$idx]}"
|
||||
@@ -1322,14 +1367,16 @@ for idx in "${!PROVIDER_MENU_ENVS[@]}"; do
|
||||
fi
|
||||
done
|
||||
|
||||
# 14) Local (Ollama) — no API key needed
|
||||
# Local (Ollama) — slot computed from the provider list so adding/removing
|
||||
# API-key providers above doesn't require renumbering by hand.
|
||||
OLLAMA_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]}))
|
||||
if [ "$OLLAMA_DETECTED" = true ]; then
|
||||
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed ${GREEN}(ollama detected)${NC}"
|
||||
echo -e " ${CYAN}$OLLAMA_CHOICE)${NC} Local (Ollama) - No API key needed ${GREEN}(ollama detected)${NC}"
|
||||
else
|
||||
echo -e " ${CYAN}14)${NC} Local (Ollama) - No API key needed"
|
||||
echo -e " ${CYAN}$OLLAMA_CHOICE)${NC} Local (Ollama) - No API key needed"
|
||||
fi
|
||||
|
||||
SKIP_CHOICE=$((8 + ${#PROVIDER_MENU_ENVS[@]} + 1))
|
||||
SKIP_CHOICE=$((OLLAMA_CHOICE + 1))
|
||||
echo -e " ${CYAN}$SKIP_CHOICE)${NC} Skip for now"
|
||||
echo ""
|
||||
|
||||
@@ -1535,6 +1582,13 @@ case $choice in
|
||||
SIGNUP_URL="https://openrouter.ai/keys"
|
||||
;;
|
||||
14)
|
||||
SELECTED_ENV_VAR="DEEPSEEK_API_KEY"
|
||||
SELECTED_PROVIDER_ID="deepseek"
|
||||
SELECTED_API_BASE="https://api.deepseek.com"
|
||||
PROVIDER_NAME="DeepSeek"
|
||||
SIGNUP_URL="https://platform.deepseek.com/api_keys"
|
||||
;;
|
||||
"$OLLAMA_CHOICE")
|
||||
# Local (Ollama) — no API key; pick model from ollama list
|
||||
if [ "$OLLAMA_DETECTED" != true ]; then
|
||||
echo ""
|
||||
@@ -1772,6 +1826,165 @@ fi
|
||||
|
||||
echo ""
|
||||
|
||||
# ============================================================
|
||||
# Vision Fallback (subagent for tool-result images)
|
||||
# ============================================================
|
||||
#
|
||||
# When a tool returns an image (browser_screenshot, render_image, etc.)
|
||||
# but the main agent's model is text-only, the framework can route the
|
||||
# image through a separate VLM subagent that returns a text caption,
|
||||
# preserving the agent's ability to reason about visual state.
|
||||
#
|
||||
# Skip entirely when the chosen main model already supports vision per
|
||||
# the catalog's ``supports_vision`` flag — the fallback would never fire
|
||||
# in that case, and prompting for it just adds friction. For text-only
|
||||
# mains we still offer the prompt so the user can wire up a captioning
|
||||
# subagent.
|
||||
|
||||
MAIN_MODEL_HAS_VISION="false"
|
||||
if [ -n "$SELECTED_MODEL" ]; then
|
||||
MAIN_MODEL_HAS_VISION=$(uv run python - "$SELECTED_MODEL" <<'PY' 2>/dev/null || echo "false"
|
||||
import sys
|
||||
from framework.llm.model_catalog import model_supports_vision
|
||||
print("true" if model_supports_vision(sys.argv[1]) else "false")
|
||||
PY
|
||||
)
|
||||
fi
|
||||
|
||||
if [ -n "$SELECTED_PROVIDER_ID" ] && [ "$MAIN_MODEL_HAS_VISION" = "true" ]; then
|
||||
# Drop any stale vision_fallback block so the config reflects the
|
||||
# current main model's capabilities.
|
||||
save_vision_fallback "" "" "" "" > /dev/null 2>&1 || true
|
||||
echo -e "${GREEN}⬢${NC} Vision fallback ${DIM}skipped — ${SELECTED_MODEL} already supports vision${NC}"
|
||||
echo ""
|
||||
elif [ -n "$SELECTED_PROVIDER_ID" ]; then
|
||||
echo -e "${YELLOW}⬢${NC} ${BLUE}${BOLD}Vision fallback subagent${NC}"
|
||||
echo ""
|
||||
echo -e " ${DIM}When a screenshot/image tool is called from a text-only model,${NC}"
|
||||
echo -e " ${DIM}the framework can route the image through a vision-capable VLM${NC}"
|
||||
echo -e " ${DIM}and inject the caption into the conversation. Inert when your${NC}"
|
||||
echo -e " ${DIM}main model already supports vision (most do).${NC}"
|
||||
echo ""
|
||||
|
||||
# Build the candidate list from the same model_catalog.json the main
|
||||
# LLM step uses — never hardcode model IDs in this script. For each
|
||||
# provider in the catalogue, pick a model whose ``supports_vision``
|
||||
# flag is true (since the fallback subagent's whole purpose is to
|
||||
# caption images — a text-only candidate would be useless). Prefer
|
||||
# the provider's default when it supports vision, otherwise fall
|
||||
# back to the first vision-capable model in the provider's list.
|
||||
# Skip the provider entirely if no model in its catalog supports
|
||||
# vision. Output one TSV row per candidate:
|
||||
# provider_id<TAB>model<TAB>env_var<TAB>display_name
|
||||
VISION_CANDIDATES_TSV=$(uv run python - <<'PY'
|
||||
import os
|
||||
from framework.llm.model_catalog import get_default_models, get_models_catalogue
|
||||
|
||||
# Map provider_id → the env-var name the framework reads its key from.
|
||||
# Mirrors PROVIDER_ENV_VARS at the top of quickstart.sh, plus how the
|
||||
# rest of the script picks an env var per provider.
|
||||
PROVIDER_KEY_ENV = {
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"gemini": "GEMINI_API_KEY",
|
||||
"groq": "GROQ_API_KEY",
|
||||
"cerebras": "CEREBRAS_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"mistral": "MISTRAL_API_KEY",
|
||||
"together": "TOGETHER_API_KEY",
|
||||
"deepseek": "DEEPSEEK_API_KEY",
|
||||
"kimi": "KIMI_API_KEY",
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
}
|
||||
|
||||
defaults = get_default_models()
|
||||
catalog = get_models_catalogue()
|
||||
for provider_id, default_model in sorted(defaults.items()):
|
||||
env = PROVIDER_KEY_ENV.get(provider_id)
|
||||
if not env:
|
||||
continue
|
||||
# GEMINI_API_KEY OR GOOGLE_API_KEY both unlock gemini
|
||||
has_key = bool(os.environ.get(env))
|
||||
if provider_id == "gemini" and not has_key:
|
||||
if os.environ.get("GOOGLE_API_KEY"):
|
||||
has_key = True
|
||||
env = "GOOGLE_API_KEY"
|
||||
if not has_key:
|
||||
continue
|
||||
# Pick a vision-capable model: prefer the catalog default if it has
|
||||
# supports_vision=true, else the first vision-capable model in the
|
||||
# provider's list. Skip the provider if none exist.
|
||||
models = catalog.get(provider_id, [])
|
||||
chosen = None
|
||||
for m in models:
|
||||
if m["id"] == default_model and m.get("supports_vision") is True:
|
||||
chosen = m["id"]
|
||||
break
|
||||
if chosen is None:
|
||||
for m in models:
|
||||
if m.get("supports_vision") is True:
|
||||
chosen = m["id"]
|
||||
break
|
||||
if chosen is None:
|
||||
continue
|
||||
# Display name: provider/model from the catalogue verbatim
|
||||
display = f"{provider_id}/{chosen}"
|
||||
print(f"{provider_id}\t{chosen}\t{env}\t{display}")
|
||||
PY
|
||||
)
|
||||
|
||||
if [ -z "$VISION_CANDIDATES_TSV" ]; then
|
||||
echo -e " ${YELLOW}No matching API keys detected for any catalog provider.${NC}"
|
||||
echo -e " ${DIM}Set an API key for any provider in model_catalog.json and rerun.${NC}"
|
||||
echo -e " ${DIM}Skipping for now — text-only models will lose image content silently.${NC}"
|
||||
else
|
||||
# Materialise into bash array for selection
|
||||
VISION_CANDIDATES=()
|
||||
while IFS= read -r line; do
|
||||
[ -n "$line" ] && VISION_CANDIDATES+=("$line")
|
||||
done <<< "$VISION_CANDIDATES_TSV"
|
||||
|
||||
echo -e " ${BOLD}Available vision-fallback models${NC} ${DIM}(from model_catalog.json):${NC}"
|
||||
echo -e " ${DIM}0)${NC} (skip — don't configure vision fallback)"
|
||||
idx=1
|
||||
for entry in "${VISION_CANDIDATES[@]}"; do
|
||||
IFS=$'\t' read -r _vp _vm _vk _vd <<< "$entry"
|
||||
echo -e " ${DIM}${idx})${NC} ${_vd} ${DIM}[\$${_vk}]${NC}"
|
||||
idx=$((idx + 1))
|
||||
done
|
||||
echo ""
|
||||
VISION_CHOICE=""
|
||||
while true; do
|
||||
read -r -p " Pick a vision-fallback model [1-${#VISION_CANDIDATES[@]}, 0=skip, default=1]: " VISION_CHOICE || VISION_CHOICE=""
|
||||
VISION_CHOICE="${VISION_CHOICE:-1}"
|
||||
if [[ "$VISION_CHOICE" =~ ^[0-9]+$ ]] && \
|
||||
[ "$VISION_CHOICE" -ge 0 ] && \
|
||||
[ "$VISION_CHOICE" -le "${#VISION_CANDIDATES[@]}" ]; then
|
||||
break
|
||||
fi
|
||||
echo -e " ${YELLOW}Please enter 0 (skip) or 1-${#VISION_CANDIDATES[@]}.${NC}"
|
||||
done
|
||||
|
||||
if [ "$VISION_CHOICE" = "0" ]; then
|
||||
# Explicit skip — drop any prior block so config stays clean.
|
||||
save_vision_fallback "" "" "" "" > /dev/null 2>&1 || true
|
||||
echo -e " ${DIM}skipped — no vision_fallback block written${NC}"
|
||||
else
|
||||
chosen="${VISION_CANDIDATES[$((VISION_CHOICE - 1))]}"
|
||||
IFS=$'\t' read -r vf_provider vf_model vf_env vf_display <<< "$chosen"
|
||||
echo -n " Saving vision_fallback... "
|
||||
if save_vision_fallback "$vf_provider" "$vf_model" "$vf_env" "" > /dev/null; then
|
||||
echo -e "${GREEN}⬢${NC}"
|
||||
echo -e " ${DIM}vision_fallback: ${vf_display} (key from \$${vf_env})${NC}"
|
||||
else
|
||||
echo -e "${RED}failed${NC}"
|
||||
echo -e " ${YELLOW}Could not write vision_fallback to ~/.hive/configuration.json — non-fatal, edit manually if needed.${NC}"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# ============================================================
|
||||
# Browser Automation (GCU) — always enabled
|
||||
# ============================================================
|
||||
|
||||
@@ -128,6 +128,7 @@ from .shell_config import (
|
||||
get_shell_source_command,
|
||||
)
|
||||
from .shopify import SHOPIFY_CREDENTIALS
|
||||
from .similarweb import SIMILARWEB_CREDENTIALS
|
||||
from .slack import SLACK_CREDENTIALS
|
||||
from .snowflake import SNOWFLAKE_CREDENTIALS
|
||||
from .store_adapter import CredentialStoreAdapter
|
||||
@@ -209,6 +210,7 @@ CREDENTIAL_SPECS = {
|
||||
**SAP_CREDENTIALS,
|
||||
**SEARCH_CREDENTIALS,
|
||||
**SERPAPI_CREDENTIALS,
|
||||
**SIMILARWEB_CREDENTIALS,
|
||||
**SHOPIFY_CREDENTIALS,
|
||||
**SLACK_CREDENTIALS,
|
||||
**SNOWFLAKE_CREDENTIALS,
|
||||
@@ -306,6 +308,7 @@ __all__ = [
|
||||
"SAP_CREDENTIALS",
|
||||
"SEARCH_CREDENTIALS",
|
||||
"SERPAPI_CREDENTIALS",
|
||||
"SIMILARWEB_CREDENTIALS",
|
||||
"SHOPIFY_CREDENTIALS",
|
||||
"SLACK_CREDENTIALS",
|
||||
"SNOWFLAKE_CREDENTIALS",
|
||||
|
||||
@@ -1125,6 +1125,29 @@ class SerpApiHealthChecker(BaseHttpHealthChecker):
|
||||
AUTH_QUERY_PARAM_NAME = "api_key"
|
||||
|
||||
|
||||
class SimilarWebHealthChecker(BaseHttpHealthChecker):
    """Health checker for SimilarWeb API key.

    Probes the v5 traffic-and-engagement endpoint; a valid key yields a
    2xx response, an invalid one a 4xx that the base class reports.
    """

    ENDPOINT = "https://api.similarweb.com/v5/website-analysis/websites/traffic-and-engagement/"
    SERVICE_NAME = "SimilarWeb"
    AUTH_TYPE = BaseHttpHealthChecker.AUTH_HEADER
    AUTH_HEADER_NAME = "api-key"
    # The raw key is the whole header value — no "Bearer " prefix.
    AUTH_HEADER_TEMPLATE = "{token}"

    def _build_params(self, credential_value: str) -> dict[str, str]:
        """Return base params plus a fixed, cheap probe query.

        A stable domain and a historical month keep the health-check
        request deterministic and inexpensive.
        """
        probe = {
            "domain": "google.com",
            "start_date": "2024-01",
            "end_date": "2024-01",
            "country": "world",
            "granularity": "monthly",
        }
        return {**super()._build_params(credential_value), **probe}
|
||||
|
||||
|
||||
class ApolloHealthChecker(BaseHttpHealthChecker):
|
||||
"""Health checker for Apollo.io API key."""
|
||||
|
||||
@@ -1386,6 +1409,7 @@ HEALTH_CHECKERS: dict[str, CredentialHealthChecker] = {
|
||||
"prometheus": PrometheusHealthChecker(),
|
||||
"resend": ResendHealthChecker(),
|
||||
"serpapi": SerpApiHealthChecker(),
|
||||
"similarweb": SimilarWebHealthChecker(),
|
||||
"slack": SlackHealthChecker(),
|
||||
"stripe": StripeHealthChecker(),
|
||||
"telegram": TelegramHealthChecker(),
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aden_tools.credentials.base import CredentialSpec
|
||||
|
||||
# MCP tool names that are gated behind the SimilarWeb API key.
_SIMILARWEB_TOOLS = [
    "similarweb_v5_traffic_and_engagement",
    "similarweb_v5_website_rank",
    "similarweb_v5_traffic_sources",
    "similarweb_v5_geography",
    "similarweb_v5_demographics",
    "similarweb_v5_company_info",
    "similarweb_v5_top_sites_by_category",
    "similarweb_v5_referrals",
    "similarweb_v5_ppc_spend",
    "similarweb_v5_geography_details",
    "similarweb_v5_similar_sites",
    "similarweb_v5_ad_networks",
    "similarweb_v5_demographics_traffic",
    "similarweb_v5_deduplicated_audience",
    "similarweb_v5_audience_interests",
    "similarweb_v5_audience_overlap",
    "similarweb_v5_technologies",
    "similarweb_v5_leading_folders",
    "similarweb_v5_popular_pages",
    "similarweb_v5_subdomains",
    "similarweb_v5_keyword_competitors",
    "similarweb_v5_keyword_opportunities",
    "similarweb_v5_serp_features",
    "similarweb_v5_organic_keywords",
    "similarweb_v5_paid_keywords",
    "similarweb_v5_serp_players",
    "similarweb_v5_social_referrals",
    "similarweb_v5_segments_list",
    "similarweb_v5_segment_analysis",
]

# Single credential spec: one API key unlocks the whole tool family.
SIMILARWEB_CREDENTIALS = {
    "similarweb": CredentialSpec(
        env_var="SIMILARWEB_API_KEY",
        tools=_SIMILARWEB_TOOLS,
        required=True,
        help_url="https://developer.similarweb.com/",
        description="API key for SimilarWeb traffic and competitor insights.",
        direct_api_key_supported=True,
        api_key_instructions="""To get a SimilarWeb API key:
1. Go to the SimilarWeb Developer Portal (https://developer.similarweb.com/)
2. Or log into your SimilarWeb Pro account at pro.similarweb.com
3. Navigate to Account Settings > API (or Data Extraction / API section)
4. Click on "Generate API Key"
5. Copy the generated API key and securely store it in your .env file""",
        credential_id="similarweb",
        credential_key="api_key",
        health_check_endpoint="https://api.similarweb.com/v5/website-analysis/websites/traffic-and-engagement/",
    )
}
|
||||
@@ -118,6 +118,7 @@ from .salesforce_tool import register_tools as register_salesforce
|
||||
from .sap_tool import register_tools as register_sap
|
||||
from .serpapi_tool import register_tools as register_serpapi
|
||||
from .shopify_tool import register_tools as register_shopify
|
||||
from .similarweb_tool import register_tools as register_similarweb
|
||||
from .slack_tool import register_tools as register_slack
|
||||
from .snowflake_tool import register_tools as register_snowflake
|
||||
from .ssl_tls_scanner import register_tools as register_ssl_tls_scanner
|
||||
@@ -320,6 +321,7 @@ def _register_unverified(
|
||||
register_salesforce(mcp, credentials=credentials)
|
||||
register_sap(mcp, credentials=credentials)
|
||||
register_shopify(mcp, credentials=credentials)
|
||||
register_similarweb(mcp, credentials=credentials)
|
||||
register_snowflake(mcp, credentials=credentials)
|
||||
register_supabase(mcp, credentials=credentials)
|
||||
register_terraform(mcp, credentials=credentials)
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
# SimilarWeb Tool
|
||||
|
||||
Integration with SimilarWeb for deep website analytics, competitor intelligence, market research data, traffic sources, and audience demographics.
|
||||
|
||||
## Overview
|
||||
|
||||
This tool enables Hive agents to interact with SimilarWeb's data intelligence infrastructure for:
|
||||
|
||||
- Website traffic analysis and engagement metrics
|
||||
- Competitor research and benchmarking
|
||||
- SEO and keyword analysis
|
||||
- Advertising strategy and PPC spend insights
|
||||
- Audience demographics and geographic distribution
|
||||
- Technical profile and company insights
|
||||
|
||||
## Available Tools
|
||||
|
||||
This integration provides the following MCP tools for comprehensive market intelligence operations:
|
||||
|
||||
**Website Overview & Traffic**
|
||||
|
||||
- `similarweb_v5_traffic_and_engagement` - Get traffic and engagement metrics (visits, duration, pages per visit, bounce rate)
|
||||
- `similarweb_v5_traffic_sources` - Get marketing channels (traffic sources) breakdown
|
||||
- `similarweb_v5_geography` - Get traffic distribution by geography
|
||||
- `similarweb_v5_geography_details` - Get detailed traffic distribution by country
|
||||
- `similarweb_v5_website_rank` - Get global, country, and category ranks for a website
|
||||
|
||||
**Competitor Intelligence**
|
||||
|
||||
- `similarweb_v5_similar_sites` - Get a list of websites similar to the given domain
|
||||
- `similarweb_v5_top_sites_by_category` - Get top sites in a specific category (e.g., 'Games', 'Lifestyle')
|
||||
- `similarweb_v5_company_info` - Get company information (HQ, industry, etc.) for a website domain
|
||||
- `similarweb_v5_technologies` - Get technologies used on the website (CMS, Ads, Analytics, etc.)
|
||||
|
||||
**Marketing Channels & Referrals**
|
||||
|
||||
- `similarweb_v5_referrals` - Get detailed referral traffic sources for a domain
|
||||
- `similarweb_v5_social_referrals` - Get traffic distribution from social networks
|
||||
- `similarweb_v5_ppc_spend` - Get estimated PPC spend for a website domain
|
||||
- `similarweb_v5_ad_networks` - Get performance data across different ad networks
|
||||
|
||||
**Keywords & Search**
|
||||
|
||||
- `similarweb_v5_keyword_competitors` - Get organic and paid keyword competitors
|
||||
- `similarweb_v5_keyword_opportunities` - Get keyword gap analysis and opportunities
|
||||
- `similarweb_v5_organic_keywords` - Get detailed organic keyword performance
|
||||
- `similarweb_v5_paid_keywords` - Get detailed paid keyword performance
|
||||
- `similarweb_v5_serp_features` - Get SERP features analysis
|
||||
- `similarweb_v5_serp_players` - Get top websites driving search traffic for keywords
|
||||
|
||||
**Website Content & Structure**
|
||||
|
||||
- `similarweb_v5_leading_folders` - Get top sub-folders by traffic
|
||||
- `similarweb_v5_popular_pages` - Get most visited pages
|
||||
- `similarweb_v5_subdomains` - Get traffic breakdown by subdomain
|
||||
|
||||
**Audience & Segments**
|
||||
|
||||
- `similarweb_v5_demographics` - Get audience demographics (age and gender)
|
||||
- `similarweb_v5_demographics_traffic` - Get traffic breakdown by audience demographic segments
|
||||
- `similarweb_v5_deduplicated_audience` - Get unique visitor count across multiple domains
|
||||
- `similarweb_v5_audience_interests` - Get interests and categories relevant to the website's audience
|
||||
- `similarweb_v5_audience_overlap` - Get shared audience between the main domain and a comparison domain
|
||||
- `similarweb_v5_segments_list` - List custom segments available for the domain
|
||||
- `similarweb_v5_segment_analysis` - Get traffic and engagement for a specific segment ID
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Get SimilarWeb API Credentials
|
||||
|
||||
1. Go to the [SimilarWeb Developer Portal](https://developer.similarweb.com/)
|
||||
2. Log into your SimilarWeb Pro account at `pro.similarweb.com`
|
||||
3. Navigate to **Account Settings** -> **API** (or Data Extraction / API section)
|
||||
4. Click on **Generate API Key**
|
||||
5. Copy the generated API key.
|
||||
|
||||
### 2. Configure Environment Variables
|
||||
|
||||
```bash
|
||||
export SIMILARWEB_API_KEY="your_api_key_here"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Here are usage examples for the available MCP tools:
|
||||
|
||||
### Website Overview & Traffic
|
||||
|
||||
```python
|
||||
similarweb_v5_traffic_and_engagement(domain="example.com", country="world", granularity="monthly")
|
||||
similarweb_v5_traffic_sources(domain="example.com", country="world")
|
||||
similarweb_v5_geography(domain="example.com")
|
||||
similarweb_v5_geography_details(domain="example.com")
|
||||
similarweb_v5_website_rank(domain="example.com")
|
||||
```
|
||||
|
||||
### Competitor Intelligence
|
||||
|
||||
```python
|
||||
similarweb_v5_similar_sites(domain="example.com")
|
||||
similarweb_v5_top_sites_by_category(category="Games", country="world")
|
||||
similarweb_v5_company_info(domain="example.com")
|
||||
similarweb_v5_technologies(domain="example.com")
|
||||
```
|
||||
|
||||
### Marketing Channels & Referrals
|
||||
|
||||
```python
|
||||
similarweb_v5_referrals(domain="example.com", country="world")
|
||||
similarweb_v5_social_referrals(domain="example.com", country="world")
|
||||
similarweb_v5_ppc_spend(domain="example.com", country="world")
|
||||
similarweb_v5_ad_networks(domain="example.com", country="world")
|
||||
```
|
||||
|
||||
### Keywords & Search
|
||||
|
||||
```python
|
||||
similarweb_v5_keyword_competitors(domain="example.com")
|
||||
similarweb_v5_keyword_opportunities(domain="example.com")
|
||||
similarweb_v5_organic_keywords(domain="example.com", country="world")
|
||||
similarweb_v5_paid_keywords(domain="example.com", country="world")
|
||||
similarweb_v5_serp_features(domain="example.com")
|
||||
similarweb_v5_serp_players(domain="example.com")
|
||||
```
|
||||
|
||||
### Website Content & Structure
|
||||
|
||||
```python
|
||||
similarweb_v5_leading_folders(domain="example.com", country="world")
|
||||
similarweb_v5_popular_pages(domain="example.com", country="world")
|
||||
similarweb_v5_subdomains(domain="example.com", country="world")
|
||||
```
|
||||
|
||||
### Audience & Segments
|
||||
|
||||
```python
|
||||
similarweb_v5_demographics(domain="example.com")
|
||||
similarweb_v5_demographics_traffic(domain="example.com")
|
||||
similarweb_v5_deduplicated_audience(domains="example.com,competitor.com", country="world")
|
||||
similarweb_v5_audience_interests(domain="example.com")
|
||||
similarweb_v5_audience_overlap(domain="example.com", compare_to="competitor.com")
|
||||
similarweb_v5_segments_list(domain="example.com")
|
||||
similarweb_v5_segment_analysis(segment_id="12345", country="world")
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
The tool sends your `SIMILARWEB_API_KEY` in the `api-key` HTTP header with every request to the endpoints under `https://api.similarweb.com`. When a credential store is configured, the framework's credential adapter supplies the key at call time, so it does not need to be hard-coded in your workspace.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tools return API errors gracefully as regular Python dictionaries with a detailed message (e.g. `{"error": "HTTP error 403: ..."}`) instead of raising exceptions.
|
||||
@@ -0,0 +1,3 @@
|
||||
"""Public entry point for the SimilarWeb MCP tool package."""

from .similarweb_tool import register_tools

__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,570 @@
|
||||
"""
|
||||
SimilarWeb Tool - Traffic and competitor insights for FastMCP.
|
||||
|
||||
Provides website analytics, demographic data, and competitor intelligence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
|
||||
def _get_api_key(credentials: CredentialStoreAdapter | None = None) -> str | dict[str, str]:
|
||||
"""Get the SimilarWeb API key from credentials or environment."""
|
||||
if credentials:
|
||||
key = credentials.get("similarweb")
|
||||
if key:
|
||||
return key
|
||||
|
||||
import os
|
||||
|
||||
env_key = os.environ.get("SIMILARWEB_API_KEY")
|
||||
if env_key:
|
||||
return env_key
|
||||
|
||||
return {
|
||||
"error": "SimilarWeb credentials not configured",
|
||||
"help": (
|
||||
"Set SIMILARWEB_API_KEY environment variable or configure "
|
||||
"via credential store. Get a key at https://developer.similarweb.com/"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _make_request(
    endpoint: str,
    api_key: str,
    params: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Issue a GET against the SimilarWeb API v5 and decode the JSON body.

    Failures never propagate: HTTP status errors, transport problems and
    bad JSON are all reported as ``{"error": ...}`` dicts, matching the
    tool-result convention used throughout this module.
    """
    url = f"https://api.similarweb.com/v5/{endpoint}"
    # v5 authenticates with the raw key in the ``api-key`` header.
    headers = {"api-key": api_key, "Accept": "application/json"}

    try:
        response = httpx.get(url, params=params or {}, headers=headers, timeout=30.0)
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as exc:
        return {"error": f"HTTP error {exc.response.status_code}: {exc.response.text}"}
    except Exception as exc:
        return {"error": f"Request failed: {str(exc)}"}
|
||||
|
||||
|
||||
def register_tools(mcp: FastMCP, credentials: CredentialStoreAdapter | None = None) -> None:
    """Register SimilarWeb V5 tools with the MCP server.

    Every tool resolves the API key at call time (so credentials can be
    configured after registration) and returns either the decoded JSON
    response or an ``{"error": ...}`` dict — it never raises.

    The credential check and the country/date-window query assembly were
    previously duplicated in each of the 29 tools; they are factored into
    the ``_call`` and ``_window`` helpers below.
    """

    def _call(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
        # Shared credential-check + request boilerplate for every tool.
        # _get_api_key returns an error dict when no key is configured;
        # surface that directly as the tool result.
        api_key = _get_api_key(credentials)
        if isinstance(api_key, dict):
            return api_key
        return _make_request(endpoint, api_key, params)

    def _window(
        country: str | None,
        start_date: str | None,
        end_date: str | None,
        **extra: Any,
    ) -> dict[str, Any]:
        # Assemble the common country/date-window query parameters,
        # dropping unset values so the API applies its own defaults.
        # country=None means the endpoint takes no country parameter.
        params: dict[str, Any] = dict(extra)
        if country is not None:
            params["country"] = country
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
        return params

    @mcp.tool()
    def similarweb_v5_traffic_and_engagement(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
        granularity: str = "monthly",
    ) -> dict[str, Any]:
        """
        Get traffic and engagement metrics for a website using V5 API.

        Args:
            domain: The website domain (e.g., 'amazon.com')
            start_date: Start date (YYYY-MM or YYYY-MM-DD)
            end_date: End date (YYYY-MM or YYYY-MM-DD)
            country: 2-letter country code or 'world'
            granularity: 'daily', 'weekly', or 'monthly'
        """
        return _call(
            "website-analysis/websites/traffic-and-engagement",
            _window(
                country,
                start_date,
                end_date,
                metrics="visits,bounce_rate,avg_visit_duration,pages_per_visit,total_page_views",
                granularity=granularity,
                domain=domain,
                web_source="desktop",
            ),
        )

    @mcp.tool()
    def similarweb_v5_website_rank(domain: str) -> dict[str, Any]:
        """Get global, country, and category ranks for a website."""
        return _call("website-analysis/websites/website-rank", {"domain": domain})

    @mcp.tool()
    def similarweb_v5_traffic_sources(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get marketing channels (traffic sources) breakdown for a website."""
        return _call(
            "website-analysis/websites/traffic-sources",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_geography(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> dict[str, Any]:
        """Get traffic distribution by geography for a website."""
        return _call(
            "website-analysis/websites/traffic-geography",
            _window(None, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_demographics(domain: str) -> dict[str, Any]:
        """Get audience demographics (age and gender) for a website."""
        return _call(
            "website-analysis/websites/demographics/aggregated", {"domain": domain}
        )

    @mcp.tool()
    def similarweb_v5_company_info(domain: str) -> dict[str, Any]:
        """Get company information (HQ, industry, etc.) for a website domain."""
        return _call(
            "website-analysis/websites/company-info/company-info", {"domain": domain}
        )

    @mcp.tool()
    def similarweb_v5_top_sites_by_category(
        category: str,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get top sites in a specific category (e.g., 'Games', 'Lifestyle')."""
        return _call(
            "website-analysis/websites/top-sites-by-category/aggregated",
            {"category": category, "country": country},
        )

    @mcp.tool()
    def similarweb_v5_referrals(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get detailed referral traffic sources for a domain."""
        return _call(
            "website-analysis/websites/referrals/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_ppc_spend(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get estimated PPC spend for a website domain."""
        return _call(
            "website-analysis/websites/ppc-spend",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_geography_details(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> dict[str, Any]:
        """Get detailed traffic distribution by country (aggregated)."""
        return _call(
            "website-analysis/websites/geography/aggregated",
            _window(None, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_similar_sites(domain: str) -> dict[str, Any]:
        """Get a list of websites similar to the given domain."""
        return _call(
            "website-analysis/websites/similar-sites/aggregated", {"domain": domain}
        )

    @mcp.tool()
    def similarweb_v5_ad_networks(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get performance data across different ad networks for a domain."""
        return _call(
            "website-analysis/ad-networks/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_demographics_traffic(domain: str) -> dict[str, Any]:
        """Get traffic breakdown by audience demographic segments."""
        return _call(
            "website-analysis/websites/traffic-by-demographics/aggregated",
            {"domain": domain},
        )

    @mcp.tool()
    def similarweb_v5_deduplicated_audience(
        domains: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """
        Get unique visitor count across multiple domains (comma-separated).

        Args:
            domains: Comma-separated domains (e.g. 'amazon.com,ebay.com')
        """
        return _call(
            "website-analysis/websites/deduplicated-audience",
            _window(country, start_date, end_date, domains=domains),
        )

    @mcp.tool()
    def similarweb_v5_audience_interests(domain: str) -> dict[str, Any]:
        """Get interests and categories relevant to the website's audience."""
        return _call(
            "website-analysis/websites/audience-interests/aggregated",
            {"domain": domain},
        )

    @mcp.tool()
    def similarweb_v5_audience_overlap(
        domain: str,
        compare_to: str,
    ) -> dict[str, Any]:
        """
        Get shared audience between the main domain and a comparison domain.

        Args:
            domain: The main domain
            compare_to: Domain to compare overlap with
        """
        return _call(
            "website-analysis/websites/audience-overlap/aggregated",
            {"domain": domain, "compare_to": compare_to},
        )

    @mcp.tool()
    def similarweb_v5_technologies(domain: str) -> dict[str, Any]:
        """Get technologies used on the website (CMS, Ads, Analytics, etc.)."""
        return _call(
            "website-analysis/websites/technologies/aggregated", {"domain": domain}
        )

    @mcp.tool()
    def similarweb_v5_leading_folders(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get top sub-folders by traffic for a domain."""
        return _call(
            "website-analysis/websites/pages/leading-folders/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_popular_pages(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get most visited pages on the given domain."""
        return _call(
            "website-content/pages/popular-pages/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_subdomains(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get traffic breakdown by subdomain for a domain."""
        return _call(
            "website-content/subdomains/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_keyword_competitors(domain: str) -> dict[str, Any]:
        """Get organic and paid keyword competitors for a domain."""
        return _call(
            "website-analysis/websites/keywords-competitors/aggregated",
            {"domain": domain},
        )

    @mcp.tool()
    def similarweb_v5_keyword_opportunities(domain: str) -> dict[str, Any]:
        """Get keyword gap analysis and opportunities for a domain."""
        return _call(
            "website-analysis/websites/keywords-opportunities/aggregated",
            {"domain": domain},
        )

    @mcp.tool()
    def similarweb_v5_serp_features(domain: str) -> dict[str, Any]:
        """Get SERP features analysis for the domain."""
        return _call(
            "website-analysis/websites/keywords/serp-features", {"domain": domain}
        )

    @mcp.tool()
    def similarweb_v5_organic_keywords(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get detailed organic keyword performance for a domain."""
        return _call(
            "website-analysis/websites/keywords/organic/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_paid_keywords(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get detailed paid keyword performance for a domain."""
        return _call(
            "website-analysis/websites/keywords/paid/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_serp_players(domain: str) -> dict[str, Any]:
        """Get top websites driving search traffic for keywords (SERP players)."""
        return _call(
            "website-analysis/websites/keywords/serp-players/aggregated",
            {"domain": domain},
        )

    @mcp.tool()
    def similarweb_v5_social_referrals(
        domain: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get traffic distribution from social networks for a domain."""
        return _call(
            "website-analysis/websites/social-referrals/aggregated",
            _window(country, start_date, end_date, domain=domain),
        )

    @mcp.tool()
    def similarweb_v5_segments_list(domain: str) -> dict[str, Any]:
        """List custom segments available for the domain."""
        return _call("segment-analysis/segments/describe", {"domain": domain})

    @mcp.tool()
    def similarweb_v5_segment_analysis(
        segment_id: str,
        start_date: str | None = None,
        end_date: str | None = None,
        country: str = "world",
    ) -> dict[str, Any]:
        """Get traffic and engagement for a specific segment ID."""
        return _call(
            "segment-analysis/segments/traffic-and-engagement",
            _window(country, start_date, end_date, segment=segment_id),
        )
|
||||
+111
-38
@@ -2207,11 +2207,20 @@ class BeelineBridge:
|
||||
pass
|
||||
_interaction_highlights.pop(tab_id, None)
|
||||
|
||||
async def scroll(self, tab_id: int, direction: str = "down", amount: int = 500) -> dict:
|
||||
"""Scroll the page.
|
||||
async def scroll(
|
||||
self,
|
||||
tab_id: int,
|
||||
direction: str = "down",
|
||||
amount: int = 500,
|
||||
selector: str | None = None,
|
||||
) -> dict:
|
||||
"""Scroll the page or a specific scrollable container.
|
||||
|
||||
Uses JavaScript to find and scroll the appropriate container.
|
||||
Handles SPAs like LinkedIn where content is in a nested scrollable div.
|
||||
If ``selector`` is given, scroll that element directly (supports
|
||||
'>>>' shadow-piercing selectors). Otherwise pick a container with
|
||||
a direction-aware heuristic that prefers the visible scroll area
|
||||
at the viewport center, falling back to the largest visible
|
||||
scrollable element, then to ``window.scrollBy``.
|
||||
"""
|
||||
delta_x = 0
|
||||
delta_y = 0
|
||||
@@ -2224,47 +2233,101 @@ class BeelineBridge:
|
||||
elif direction == "left":
|
||||
delta_x = -amount
|
||||
|
||||
# JavaScript scroll that finds the largest scrollable container
|
||||
# NOTE: Do NOT wrap in IIFE - evaluate() already wraps scripts
|
||||
# Direction axis: only consider candidates that can actually scroll
|
||||
# along the requested axis. 'y' for up/down, 'x' for left/right.
|
||||
axis = "y" if direction in ("up", "down") else "x"
|
||||
|
||||
selector_json = json.dumps(selector) if selector else "null"
|
||||
|
||||
# NOTE: Do NOT wrap in IIFE — evaluate() already wraps scripts.
|
||||
# Use behavior:'instant' so the post-scroll snapshot reflects the
|
||||
# final state without racing the smooth-scroll animation.
|
||||
scroll_script = f"""
|
||||
// Find the largest scrollable container
|
||||
const candidates = [];
|
||||
const allElements = document.querySelectorAll('*');
|
||||
{self._SHADOW_QUERY_JS}
|
||||
|
||||
for (const el of allElements) {{
|
||||
const style = getComputedStyle(el);
|
||||
const overflow = style.overflow + style.overflowY;
|
||||
const dx = {delta_x};
|
||||
const dy = {delta_y};
|
||||
const axis = {json.dumps(axis)};
|
||||
const userSelector = {selector_json};
|
||||
|
||||
if (overflow.includes('scroll') || overflow.includes('auto')) {{
|
||||
const rect = el.getBoundingClientRect();
|
||||
if (rect.width > 100 && rect.height > 100 &&
|
||||
el.scrollHeight > el.clientHeight + 100) {{
|
||||
candidates.push({{el: el, area: rect.width * rect.height}});
|
||||
}}
|
||||
function canScroll(el) {{
|
||||
if (!el || el.nodeType !== 1) return false;
|
||||
if (el === document.scrollingElement || el === document.documentElement || el === document.body) {{
|
||||
return axis === 'y'
|
||||
? document.documentElement.scrollHeight > window.innerHeight + 1
|
||||
: document.documentElement.scrollWidth > window.innerWidth + 1;
|
||||
}}
|
||||
const style = getComputedStyle(el);
|
||||
if (style.visibility === 'hidden' || style.display === 'none') return false;
|
||||
const overflow = axis === 'y'
|
||||
? (style.overflowY + style.overflow)
|
||||
: (style.overflowX + style.overflow);
|
||||
if (!/auto|scroll|overlay/.test(overflow)) return false;
|
||||
return axis === 'y'
|
||||
? el.scrollHeight > el.clientHeight + 1
|
||||
: el.scrollWidth > el.clientWidth + 1;
|
||||
}}
|
||||
|
||||
function findScrollableAncestor(el) {{
|
||||
let node = el;
|
||||
while (node && node !== document.body && node !== document.documentElement) {{
|
||||
if (canScroll(node)) return node;
|
||||
node = node.parentElement;
|
||||
}}
|
||||
return null;
|
||||
}}
|
||||
|
||||
// 1. Explicit selector wins
|
||||
if (userSelector) {{
|
||||
const el = userSelector.includes('>>>')
|
||||
? _shadowQuery(userSelector)
|
||||
: document.querySelector(userSelector);
|
||||
if (!el) {{
|
||||
return {{ success: false, error: 'selector_not_found', selector: userSelector }};
|
||||
}}
|
||||
if (!canScroll(el)) {{
|
||||
return {{ success: false, error: 'not_scrollable_in_direction',
|
||||
selector: userSelector, axis: axis, tag: el.tagName }};
|
||||
}}
|
||||
el.scrollBy({{ top: dy, left: dx, behavior: 'instant' }});
|
||||
return {{ success: true, method: 'selector', tag: el.tagName }};
|
||||
}}
|
||||
|
||||
// 2. Prefer the scrollable ancestor at the viewport center —
|
||||
// a much better proxy for "what the agent is looking at"
|
||||
// than "largest element on the page."
|
||||
const cx = window.innerWidth / 2;
|
||||
const cy = window.innerHeight / 2;
|
||||
const elAtCenter = document.elementFromPoint(cx, cy);
|
||||
const centerHit = findScrollableAncestor(elAtCenter);
|
||||
if (centerHit) {{
|
||||
centerHit.scrollBy({{ top: dy, left: dx, behavior: 'instant' }});
|
||||
return {{ success: true, method: 'viewport-center', tag: centerHit.tagName }};
|
||||
}}
|
||||
|
||||
// 3. Fallback: largest visible scrollable element on the
|
||||
// correct axis. Filters out hidden/offscreen drawers and
|
||||
// elements that scroll the wrong way.
|
||||
const candidates = [];
|
||||
for (const el of document.querySelectorAll('*')) {{
|
||||
if (!canScroll(el)) continue;
|
||||
const rect = el.getBoundingClientRect();
|
||||
if (rect.width < 50 || rect.height < 50) continue;
|
||||
// Must intersect the viewport
|
||||
if (rect.bottom <= 0 || rect.top >= window.innerHeight) continue;
|
||||
if (rect.right <= 0 || rect.left >= window.innerWidth) continue;
|
||||
candidates.push({{ el: el, area: rect.width * rect.height }});
|
||||
}}
|
||||
candidates.sort((a, b) => b.area - a.area);
|
||||
const container = candidates.length > 0 ? candidates[0].el : null;
|
||||
|
||||
if (container) {{
|
||||
container.scrollBy({{ top: {delta_y}, left: {delta_x}, behavior: 'smooth' }});
|
||||
return {{
|
||||
success: true,
|
||||
method: 'container',
|
||||
tag: container.tagName,
|
||||
scrolled: true
|
||||
}};
|
||||
if (candidates.length > 0) {{
|
||||
const container = candidates[0].el;
|
||||
container.scrollBy({{ top: dy, left: dx, behavior: 'instant' }});
|
||||
return {{ success: true, method: 'largest-visible', tag: container.tagName }};
|
||||
}}
|
||||
|
||||
// Fallback to window scroll
|
||||
window.scrollBy({{ top: {delta_y}, left: {delta_x}, behavior: 'smooth' }});
|
||||
return {{
|
||||
success: true,
|
||||
method: 'window',
|
||||
tag: 'WINDOW',
|
||||
scrolled: true
|
||||
}};
|
||||
// 4. Last resort: window scroll
|
||||
window.scrollBy({{ top: dy, left: dx, behavior: 'instant' }});
|
||||
return {{ success: true, method: 'window', tag: 'WINDOW' }};
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -2280,8 +2343,18 @@ class BeelineBridge:
|
||||
"method": value.get("method", "js"),
|
||||
"container": value.get("tag", "unknown"),
|
||||
}
|
||||
else:
|
||||
return {"ok": False, "error": "scroll script returned failure"}
|
||||
err = value.get("error") or "scroll script returned failure"
|
||||
if err == "selector_not_found":
|
||||
return {"ok": False, "error": f"Element not found: {value.get('selector')}"}
|
||||
if err == "not_scrollable_in_direction":
|
||||
return {
|
||||
"ok": False,
|
||||
"error": (
|
||||
f"Element {value.get('tag')} ({value.get('selector')}) is not "
|
||||
f"scrollable along the {value.get('axis')} axis"
|
||||
),
|
||||
}
|
||||
return {"ok": False, "error": err}
|
||||
|
||||
except TimeoutError:
|
||||
return {"ok": False, "error": "scroll timed out"}
|
||||
|
||||
@@ -110,10 +110,11 @@ TOOL_SCHEMAS: dict[str, dict] = {
|
||||
},
|
||||
},
|
||||
"browser_scroll": {
|
||||
"description": "Scroll the page.",
|
||||
"description": "Scroll the page or a specific scrollable container.",
|
||||
"params": {
|
||||
"direction": {"type": "string", "default": "down", "enum": ["up", "down", "left", "right"]},
|
||||
"amount": {"type": "integer", "default": 500},
|
||||
"selector": {"type": "string"},
|
||||
"tab_id": {"type": "integer"},
|
||||
"profile": {"type": "string"},
|
||||
},
|
||||
|
||||
@@ -845,16 +845,24 @@ def register_interaction_tools(mcp: FastMCP) -> None:
|
||||
async def browser_scroll(
|
||||
direction: Literal["up", "down", "left", "right"] = "down",
|
||||
amount: int = 500,
|
||||
selector: str | None = None,
|
||||
tab_id: int | None = None,
|
||||
profile: str | None = None,
|
||||
auto_snapshot_mode: AutoSnapshotMode = "default",
|
||||
) -> dict:
|
||||
"""
|
||||
Scroll the page.
|
||||
Scroll the page or a specific scrollable container.
|
||||
|
||||
Args:
|
||||
direction: Scroll direction (up, down, left, right)
|
||||
amount: Scroll amount in pixels (default: 500)
|
||||
selector: Optional CSS selector for the container to scroll.
|
||||
Supports '>>>' shadow-piercing selectors. When omitted,
|
||||
the tool picks the scrollable container at the viewport
|
||||
center, then falls back to the largest visible
|
||||
scrollable element, then to the window. Use this when
|
||||
auto-pick scrolls the wrong area (e.g. nested panels,
|
||||
modals over a long page, chat history beside a sidebar).
|
||||
tab_id: Chrome tab ID (default: active tab)
|
||||
profile: Browser profile name (default: "default")
|
||||
auto_snapshot_mode: Controls the accessibility snapshot taken
|
||||
@@ -869,7 +877,13 @@ def register_interaction_tools(mcp: FastMCP) -> None:
|
||||
``auto_snapshot_mode="off"`` or the scroll failed.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
params = {"direction": direction, "amount": amount, "tab_id": tab_id, "profile": profile}
|
||||
params = {
|
||||
"direction": direction,
|
||||
"amount": amount,
|
||||
"selector": selector,
|
||||
"tab_id": tab_id,
|
||||
"profile": profile,
|
||||
}
|
||||
|
||||
bridge = get_bridge()
|
||||
if not bridge or not bridge.is_connected:
|
||||
@@ -890,7 +904,7 @@ def register_interaction_tools(mcp: FastMCP) -> None:
|
||||
return result
|
||||
|
||||
try:
|
||||
scroll_result = await bridge.scroll(target_tab, direction=direction, amount=amount)
|
||||
scroll_result = await bridge.scroll(target_tab, direction=direction, amount=amount, selector=selector)
|
||||
log_tool_call(
|
||||
"browser_scroll",
|
||||
params,
|
||||
|
||||
@@ -201,7 +201,7 @@ class TestComplexScriptExecution:
|
||||
"""Test LinkedIn-style infinite feed scrolling with lazy loading."""
|
||||
scroll_calls = []
|
||||
|
||||
async def mock_scroll(tab_id: int, direction: str, amount: int = 500) -> dict:
|
||||
async def mock_scroll(tab_id: int, direction: str, amount: int = 500, selector: str | None = None) -> dict:
|
||||
scroll_calls.append((tab_id, direction, amount))
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class TestHealthCheckerRegistry:
|
||||
"prometheus",
|
||||
"resend",
|
||||
"serpapi",
|
||||
"similarweb",
|
||||
"slack",
|
||||
"stripe",
|
||||
"telegram",
|
||||
|
||||
@@ -0,0 +1,581 @@
|
||||
"""Tests for similarweb_tool - Website traffic and competitor analytics (V5 API)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.similarweb_tool.similarweb_tool import register_tools
|
||||
|
||||
|
||||
class MockCredentials:
|
||||
def get(self, key: str) -> str | None:
|
||||
if key == "similarweb":
|
||||
return "test_api_key_123"
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credentials() -> MockCredentials:
|
||||
return MockCredentials()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_with_tools(credentials: MockCredentials) -> FastMCP:
|
||||
mcp = FastMCP("SimilarWebTest")
|
||||
register_tools(mcp, credentials=credentials)
|
||||
return mcp
|
||||
|
||||
|
||||
class TestSimilarWebToolV5:
|
||||
def _mock_response(self, mock_get: MagicMock, json_data: dict) -> None:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = json_data
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
def _assert_v5_request(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
expected_full_endpoint: str,
|
||||
expected_params: dict | None = None,
|
||||
) -> None:
|
||||
mock_get.assert_called_once()
|
||||
actual_url = mock_get.call_args[0][0]
|
||||
expected_url = f"https://api.similarweb.com/v5/{expected_full_endpoint}"
|
||||
assert actual_url == expected_url
|
||||
|
||||
call_kwargs = mock_get.call_args[1]
|
||||
assert call_kwargs["headers"]["api-key"] == "test_api_key_123"
|
||||
assert call_kwargs["headers"]["Accept"] == "application/json"
|
||||
|
||||
if expected_params:
|
||||
for k, v in expected_params.items():
|
||||
assert call_kwargs["params"][k] == v
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_traffic_and_engagement_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {
|
||||
"meta": {"request": {"domain": "amazon.com"}},
|
||||
"visits": [{"date": "2023-01-01", "visits": 1000}],
|
||||
}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_traffic_and_engagement"]
|
||||
result = tool.fn(domain="amazon.com", country="us", granularity="daily")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/traffic-and-engagement",
|
||||
{
|
||||
"domain": "amazon.com",
|
||||
"country": "us",
|
||||
"granularity": "daily",
|
||||
"metrics": "visits,bounce_rate,avg_visit_duration,pages_per_visit,total_page_views",
|
||||
"web_source": "desktop",
|
||||
},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_website_rank_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"global_rank": 10, "country_rank": 5}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_website_rank"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/website-rank", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_traffic_sources_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"search": 0.4, "direct": 0.3}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_traffic_sources"]
|
||||
result = tool.fn(domain="amazon.com", country="world")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/traffic-sources", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_geography_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"top_countries": [{"country": "US", "share": 0.5}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_geography"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/traffic-geography", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_demographics_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"age_distribution": {"18-24": 0.2}}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_demographics"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/demographics/aggregated", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_company_info_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"company_name": "Amazon", "headquarters": "Seattle, WA"}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_company_info"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/company-info/company-info", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_top_sites_by_category_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"top_sites": [{"domain": "google.com", "rank": 1}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_top_sites_by_category"]
|
||||
result = tool.fn(category="Search Engines")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/top-sites-by-category/aggregated",
|
||||
{"category": "Search Engines", "country": "world"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_keyword_competitors_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"competitors": [{"domain": "competitor.com", "overlap": 0.8}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_keyword_competitors"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/keywords-competitors/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_technologies_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"technologies": [{"name": "React", "category": "Frontend Framework"}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_technologies"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/technologies/aggregated", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_deduplicated_audience_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"total_unique_visitors": 1000000}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_deduplicated_audience"]
|
||||
result = tool.fn(domains="amazon.com,ebay.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/deduplicated-audience",
|
||||
{"domains": "amazon.com,ebay.com", "country": "world"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_referrals_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"referrals": [{"domain": "google.com", "share": 0.5}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_referrals"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/referrals/aggregated", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_ppc_spend_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"ppc_spend": 5000}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_ppc_spend"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/ppc-spend", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_geography_details_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"top_countries": [{"country": "US", "share": 0.5}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_geography_details"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/geography/aggregated", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_similar_sites_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"similar_sites": [{"domain": "ebay.com"}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_similar_sites"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/similar-sites/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_ad_networks_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"ad_networks": [{"name": "Google Ads", "share": 0.5}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_ad_networks"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/ad-networks/aggregated", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_demographics_traffic_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"demographics": {"male": 0.5, "female": 0.5}}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_demographics_traffic"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/traffic-by-demographics/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_audience_interests_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"interests": ["shopping", "tech"]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_audience_interests"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/audience-interests/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_audience_overlap_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"overlap": 0.3}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_audience_overlap"]
|
||||
result = tool.fn(domain="amazon.com", compare_to="ebay.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/audience-overlap/aggregated",
|
||||
{"domain": "amazon.com", "compare_to": "ebay.com"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_leading_folders_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"folders": [{"name": "/products/", "share": 0.4}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_leading_folders"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/pages/leading-folders/aggregated",
|
||||
{"domain": "amazon.com", "country": "world"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_popular_pages_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"pages": [{"url": "amazon.com/best-sellers", "share": 0.1}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_popular_pages"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-content/pages/popular-pages/aggregated", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_subdomains_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"subdomains": [{"name": "aws.amazon.com", "share": 0.2}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_subdomains"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-content/subdomains/aggregated", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_keyword_opportunities_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"opportunities": [{"keyword": "buy electronics", "score": 90}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_keyword_opportunities"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/keywords-opportunities/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_serp_features_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"serp_features": {"featured_snippets": 10}}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_serp_features"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(mock_get, "website-analysis/websites/keywords/serp-features", {"domain": "amazon.com"})
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_organic_keywords_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"keywords": [{"phrase": "shopping", "visits": 1000}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_organic_keywords"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/keywords/organic/aggregated",
|
||||
{"domain": "amazon.com", "country": "world"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_paid_keywords_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"keywords": [{"phrase": "buy books", "visits": 500}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_paid_keywords"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/keywords/paid/aggregated", {"domain": "amazon.com", "country": "world"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_serp_players_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"players": [{"domain": "walmart.com", "share": 0.1}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_serp_players"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get, "website-analysis/websites/keywords/serp-players/aggregated", {"domain": "amazon.com"}
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
|
||||
def test_similarweb_v5_social_referrals_success(
|
||||
self,
|
||||
mock_get: MagicMock,
|
||||
mcp_with_tools: FastMCP,
|
||||
) -> None:
|
||||
response_data = {"social": [{"name": "Facebook", "share": 0.5}]}
|
||||
self._mock_response(mock_get, response_data)
|
||||
|
||||
tool = mcp_with_tools._tool_manager._tools["similarweb_v5_social_referrals"]
|
||||
result = tool.fn(domain="amazon.com")
|
||||
|
||||
self._assert_v5_request(
|
||||
mock_get,
|
||||
"website-analysis/websites/social-referrals/aggregated",
|
||||
{"domain": "amazon.com", "country": "world"},
|
||||
)
|
||||
assert result == response_data
|
||||
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
def test_similarweb_v5_segments_list_success(
    self,
    mock_get: MagicMock,
    mcp_with_tools: FastMCP,
) -> None:
    """Segments-list tool calls the segment describe endpoint and returns the raw API payload."""
    payload = {"segments": [{"id": "seg1", "name": "Segment 1"}]}
    self._mock_response(mock_get, payload)

    segments_list = mcp_with_tools._tool_manager._tools["similarweb_v5_segments_list"]
    result = segments_list.fn(domain="amazon.com")

    self._assert_v5_request(
        mock_get,
        "segment-analysis/segments/describe",
        {"domain": "amazon.com"},
    )
    assert result == payload
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
def test_similarweb_v5_segment_analysis_success(
    self,
    mock_get: MagicMock,
    mcp_with_tools: FastMCP,
) -> None:
    """Segment-analysis tool queries traffic-and-engagement by segment id and returns the payload."""
    payload = {"metrics": {"visits": 1000}}
    self._mock_response(mock_get, payload)

    segment_analysis = mcp_with_tools._tool_manager._tools["similarweb_v5_segment_analysis"]
    # This tool takes a segment id rather than a domain.
    result = segment_analysis.fn(segment_id="seg1")

    self._assert_v5_request(
        mock_get,
        "segment-analysis/segments/traffic-and-engagement",
        {"segment": "seg1", "country": "world"},
    )
    assert result == payload
|
||||
@patch("aden_tools.tools.similarweb_tool.similarweb_tool.httpx.get")
def test_make_request_error_handling(
    self,
    mock_get: MagicMock,
    mcp_with_tools: FastMCP,
) -> None:
    """An HTTP 403 surfaces as an ``{"error": ...}`` dict instead of raising."""
    # Build a response whose raise_for_status() throws a 403 HTTPStatusError.
    forbidden = MagicMock()
    forbidden.status_code = 403
    forbidden.text = "Forbidden"
    forbidden.raise_for_status.side_effect = httpx.HTTPStatusError(
        "403 Forbidden", request=MagicMock(), response=forbidden
    )
    mock_get.return_value = forbidden

    website_rank = mcp_with_tools._tool_manager._tools["similarweb_v5_website_rank"]
    result = website_rank.fn(domain="forbidden.com")

    # The error message should carry both the status code and the body text.
    assert "error" in result
    assert "403" in result["error"]
    assert "Forbidden" in result["error"]
Reference in New Issue
Block a user