Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| facd919371 | |||
| cb1484be85 | |||
| da361f735d | |||
| eea0429f93 | |||
| 833aa4bc7a | |||
| 0af597881f | |||
| 6fae1f04c8 | |||
| 8c4085f5e8 | |||
| 53240eb888 | |||
| de8d6f0946 | |||
| ea707438f2 | |||
| 445c9600ab | |||
| 2ab5e6d784 | |||
| e7f9b7d791 | |||
| 3cb0c69a96 | |||
| 22d75bfb05 | |||
| 357df1bbcb | |||
| 386bbd5780 | |||
| 235022b35d | |||
| 4d8f312c3e | |||
| 4651a6a85a | |||
| ea9c163438 | |||
| 77cc169606 | |||
| 8c6428f445 | |||
| 44cb0c0f4c | |||
| 2621fb88b1 | |||
| a70f92edbe | |||
| b2efa179ea | |||
| 8c6e76d052 | |||
| c7f1fbf19f | |||
| 7047ecbf46 | |||
| b96ee5aaab | |||
| 6744bea01a | |||
| 390038225b | |||
| b55c8fdf86 | |||
| e9aea0bbc4 | |||
| 0ba1fa8262 | |||
| 0fd96d410e | |||
| c658a7c50b | |||
| 56c3659bda | |||
| 14f927996c | |||
| 8a0ec070b8 | |||
| 80cd77ac30 |
@@ -85,7 +85,12 @@ from framework.agent_loop.internals.types import (
|
||||
JudgeVerdict,
|
||||
TriggerEvent,
|
||||
)
|
||||
from framework.agent_loop.internals.vision_fallback import (
|
||||
caption_tool_image,
|
||||
extract_intent_for_tool,
|
||||
)
|
||||
from framework.agent_loop.types import AgentContext, AgentProtocol, AgentResult
|
||||
from framework.config import get_vision_fallback_model
|
||||
from framework.host.event_bus import EventBus
|
||||
from framework.llm.capabilities import filter_tools_for_model, supports_image_tool_results
|
||||
from framework.llm.provider import Tool, ToolResult, ToolUse
|
||||
@@ -219,6 +224,46 @@ async def _describe_images_as_text(image_content: list[dict[str, Any]]) -> str |
|
||||
return None
|
||||
|
||||
|
||||
def _vision_fallback_active(model: str | None) -> bool:
|
||||
"""Return True if tool-result images for *model* should be routed
|
||||
through the vision-fallback chain rather than sent to the model.
|
||||
|
||||
Trigger: the model's catalog entry has ``supports_vision: false``
|
||||
(resolved via :func:`capabilities.supports_image_tool_results`,
|
||||
which reads ``model_catalog.json``). Unknown models default to
|
||||
vision-capable, so the fallback only fires when the catalog
|
||||
explicitly says the model is text-only.
|
||||
|
||||
The ``vision_fallback`` config block is the *substitution* model —
|
||||
it doesn't widen the trigger. To force fallback for a model that
|
||||
isn't catalogued yet, add an entry to ``model_catalog.json`` with
|
||||
``supports_vision: false`` rather than relying on a runtime config.
|
||||
"""
|
||||
if not model:
|
||||
return False
|
||||
return not supports_image_tool_results(model)
|
||||
|
||||
|
||||
async def _captioning_chain(
|
||||
intent: str,
|
||||
image_content: list[dict[str, Any]],
|
||||
) -> str | None:
|
||||
"""Two-stage caption chain used by the agent-loop tool-result hook.
|
||||
|
||||
Stage 1: configured ``vision_fallback`` model with intent + images.
|
||||
Stage 2: generic-caption rotation (gpt-4o-mini → claude-3-haiku
|
||||
→ gemini-flash) when stage 1 is unconfigured or fails.
|
||||
|
||||
Returns the caption text or None if both stages fail. Caller is
|
||||
responsible for the placeholder-on-None and the splice into the
|
||||
persisted tool-result content.
|
||||
"""
|
||||
caption = await caption_tool_image(intent, image_content)
|
||||
if not caption:
|
||||
caption = await _describe_images_as_text(image_content)
|
||||
return caption
|
||||
|
||||
|
||||
# Pattern for detecting context-window-exceeded errors across LLM providers.
|
||||
_CONTEXT_TOO_LARGE_RE = re.compile(
|
||||
r"context.{0,20}(length|window|limit|size)|"
|
||||
@@ -376,6 +421,14 @@ class AgentLoop(AgentProtocol):
|
||||
# dashboards can build aggregates over many runs.
|
||||
self._counters: dict[str, int] = {}
|
||||
|
||||
# Task-system reminder state (see framework/tasks/reminders.py).
|
||||
# Bumped each iteration; reset whenever a task op tool was called
|
||||
# in the iteration that just completed; nudges the agent via the
|
||||
# injection queue when it's been silent on tasks for too long.
|
||||
from framework.tasks.reminders import ReminderState as _RS
|
||||
|
||||
self._task_reminder_state: _RS = _RS()
|
||||
|
||||
def _bump(self, key: str, by: int = 1) -> None:
|
||||
"""Increment a reliability counter (creates the key on first use)."""
|
||||
self._counters[key] = self._counters.get(key, 0) + by
|
||||
@@ -575,6 +628,7 @@ class AgentLoop(AgentProtocol):
|
||||
store=self._conversation_store,
|
||||
run_id=ctx.effective_run_id,
|
||||
compaction_buffer_tokens=self._config.compaction_buffer_tokens,
|
||||
compaction_buffer_ratio=self._config.compaction_buffer_ratio,
|
||||
compaction_warning_buffer_tokens=(self._config.compaction_warning_buffer_tokens),
|
||||
)
|
||||
accumulator = OutputAccumulator(
|
||||
@@ -587,7 +641,12 @@ class AgentLoop(AgentProtocol):
|
||||
|
||||
initial_message = self._build_initial_message(ctx)
|
||||
if initial_message:
|
||||
await conversation.add_user_message(initial_message)
|
||||
# Stamp with arrival time so the conversation has a
|
||||
# temporal anchor for the first turn, matching the
|
||||
# stamping done by drain_injection_queue for every
|
||||
# subsequent event.
|
||||
_stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
await conversation.add_user_message(f"[{_stamp}] {initial_message}")
|
||||
|
||||
await self._run_hooks("session_start", conversation, trigger=initial_message)
|
||||
|
||||
@@ -599,7 +658,8 @@ class AgentLoop(AgentProtocol):
|
||||
initial_message = self._build_initial_message(ctx)
|
||||
if not initial_message:
|
||||
initial_message = "Hello"
|
||||
await conversation.add_user_message(initial_message)
|
||||
_stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
await conversation.add_user_message(f"[{_stamp}] {initial_message}")
|
||||
|
||||
# 2b. Restore spill counter from existing files (resume safety)
|
||||
self._restore_spill_counter()
|
||||
@@ -619,8 +679,23 @@ class AgentLoop(AgentProtocol):
|
||||
# Hide image-producing tools from text-only models so they never try
|
||||
# to call them. Avoids wasted turns + "screenshot failed" lessons
|
||||
# getting saved to memory. See framework.llm.capabilities.
|
||||
# EXCEPTION: when the model IS on the text-only deny list AND
|
||||
# a vision_fallback subagent is configured, leave image tools
|
||||
# visible. The post-execution hook in the inner tool loop
|
||||
# will route each image_content through the fallback VLM and
|
||||
# replace it with a text caption before the main agent sees
|
||||
# the result — so the main agent gets captions instead of
|
||||
# raw images, rather than losing the tool entirely. We DON'T
|
||||
# bypass the filter for vision-capable models (that would be
|
||||
# a no-op anyway — the filter doesn't fire for them) and we
|
||||
# DON'T bypass it without a configured fallback (the agent
|
||||
# would just see raw stripped tool results with no caption).
|
||||
_llm_model = ctx.llm.model if ctx.llm else ""
|
||||
tools, _hidden_image_tools = filter_tools_for_model(tools, _llm_model)
|
||||
_text_only_main = _llm_model and not supports_image_tool_results(_llm_model)
|
||||
if _text_only_main and get_vision_fallback_model() is not None:
|
||||
_hidden_image_tools: list[str] = []
|
||||
else:
|
||||
tools, _hidden_image_tools = filter_tools_for_model(tools, _llm_model)
|
||||
|
||||
logger.info(
|
||||
"[%s] Tools available (%d): %s | direct_user_io=%s | judge=%s | hidden_image_tools=%s",
|
||||
@@ -793,14 +868,56 @@ class AgentLoop(AgentProtocol):
|
||||
tools.extend(synthetic)
|
||||
|
||||
# 6b3. Dynamic prompt refresh (phase switching / memory refresh)
|
||||
if ctx.dynamic_prompt_provider is not None or ctx.dynamic_memory_provider is not None:
|
||||
if (
|
||||
ctx.dynamic_prompt_provider is not None
|
||||
or ctx.dynamic_memory_provider is not None
|
||||
or ctx.dynamic_skills_catalog_provider is not None
|
||||
):
|
||||
if ctx.dynamic_prompt_provider is not None:
|
||||
_new_prompt = stamp_prompt_datetime(ctx.dynamic_prompt_provider())
|
||||
_new_prompt = ctx.dynamic_prompt_provider()
|
||||
# When a suffix provider is also wired (Queen's
|
||||
# static/dynamic split), keep the two pieces separate
|
||||
# so the LLM wrapper can emit them as two system
|
||||
# content blocks with a cache breakpoint between them.
|
||||
# The timestamp used to be stamped here via
|
||||
# stamp_prompt_datetime on every iteration — it now
|
||||
# lives inside the frozen dynamic suffix and is only
|
||||
# refreshed at user-turn boundaries, so per-iteration
|
||||
# stamping would both double-stamp and bust the cache.
|
||||
_new_suffix: str | None = None
|
||||
if ctx.dynamic_prompt_suffix_provider is not None:
|
||||
try:
|
||||
_new_suffix = ctx.dynamic_prompt_suffix_provider() or ""
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[%s] dynamic_prompt_suffix_provider raised — falling back to legacy stamp",
|
||||
node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
_new_suffix = None
|
||||
if _new_suffix is None:
|
||||
# Legacy / fallback path: no split in use (or the
|
||||
# suffix provider raised). Stamp the timestamp at
|
||||
# the end of the single-string prompt so the model
|
||||
# still sees a current "now".
|
||||
_new_prompt = stamp_prompt_datetime(_new_prompt)
|
||||
else:
|
||||
# build_system_prompt_for_context reads dynamic_skills_catalog_provider
|
||||
# directly; no separate branch needed.
|
||||
_new_prompt = build_system_prompt_for_context(ctx)
|
||||
if _new_prompt != conversation.system_prompt:
|
||||
conversation.update_system_prompt(_new_prompt)
|
||||
logger.info("[%s] Dynamic prompt updated", node_id)
|
||||
_new_suffix = None
|
||||
if _new_suffix is not None:
|
||||
_combined_for_compare = f"{_new_prompt}\n\n{_new_suffix}" if _new_suffix else _new_prompt
|
||||
if (
|
||||
_combined_for_compare != conversation.system_prompt
|
||||
or _new_suffix != conversation.system_prompt_dynamic_suffix
|
||||
):
|
||||
conversation.update_system_prompt(_new_prompt, dynamic_suffix=_new_suffix)
|
||||
logger.info("[%s] Dynamic prompt updated (split)", node_id)
|
||||
else:
|
||||
if _new_prompt != conversation.system_prompt:
|
||||
conversation.update_system_prompt(_new_prompt)
|
||||
logger.info("[%s] Dynamic prompt updated", node_id)
|
||||
|
||||
# 6c. Publish iteration event (with per-iteration metadata when available)
|
||||
_iter_meta = None
|
||||
@@ -882,6 +999,17 @@ class AgentLoop(AgentProtocol):
|
||||
)
|
||||
total_input_tokens += turn_tokens.get("input", 0)
|
||||
total_output_tokens += turn_tokens.get("output", 0)
|
||||
|
||||
# Task-system reminder: if the model has been silent on
|
||||
# task ops for too long but still has open tasks, drop
|
||||
# a steering reminder onto the injection queue. Drained
|
||||
# at the next iteration's 6b so it lands as the next
|
||||
# user turn via the normal injection path. Best-effort
|
||||
# — never raises.
|
||||
try:
|
||||
await self._maybe_inject_task_reminder(ctx, logged_tool_calls)
|
||||
except Exception:
|
||||
logger.debug("task reminder check failed", exc_info=True)
|
||||
await self._publish_llm_turn_complete(
|
||||
stream_id,
|
||||
node_id,
|
||||
@@ -890,6 +1018,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens=turn_tokens.get("input", 0),
|
||||
output_tokens=turn_tokens.get("output", 0),
|
||||
cached_tokens=turn_tokens.get("cached", 0),
|
||||
cache_creation_tokens=turn_tokens.get("cache_creation", 0),
|
||||
cost_usd=float(turn_tokens.get("cost", 0.0) or 0.0),
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
@@ -2290,7 +2420,9 @@ class AgentLoop(AgentProtocol):
|
||||
stream_id = ctx.stream_id or ctx.agent_id
|
||||
node_id = ctx.agent_id
|
||||
execution_id = ctx.execution_id or ""
|
||||
token_counts: dict[str, int] = {"input": 0, "output": 0, "cached": 0}
|
||||
# Mixed-type dict: int token counts + str stop_reason/model + float cost.
|
||||
# Typed loosely to avoid churn in the many call sites that read from it.
|
||||
token_counts: dict[str, Any] = {"input": 0, "output": 0, "cached": 0, "cache_creation": 0, "cost": 0.0}
|
||||
tool_call_count = 0
|
||||
final_text = ""
|
||||
final_system_prompt = conversation.system_prompt
|
||||
@@ -2431,9 +2563,16 @@ class AgentLoop(AgentProtocol):
|
||||
nonlocal _first_event_at
|
||||
_clean_snapshot = "" # visible-only text for the frontend
|
||||
|
||||
# Split-prompt path: pass STATIC and DYNAMIC tail separately
|
||||
# so the LLM wrapper can emit them as two Anthropic system
|
||||
# content blocks with a cache breakpoint between them. When
|
||||
# no split is in use, ``system_prompt_static`` equals the
|
||||
# full prompt and the suffix is empty — identical to the
|
||||
# legacy single-block request.
|
||||
async for event in ctx.llm.stream(
|
||||
messages=_msgs,
|
||||
system=conversation.system_prompt,
|
||||
system=conversation.system_prompt_static,
|
||||
system_dynamic_suffix=(conversation.system_prompt_dynamic_suffix or None),
|
||||
tools=tools if tools else None,
|
||||
max_tokens=ctx.max_tokens,
|
||||
):
|
||||
@@ -2514,6 +2653,8 @@ class AgentLoop(AgentProtocol):
|
||||
token_counts["input"] += event.input_tokens
|
||||
token_counts["output"] += event.output_tokens
|
||||
token_counts["cached"] += event.cached_tokens
|
||||
token_counts["cache_creation"] += event.cache_creation_tokens
|
||||
token_counts["cost"] = token_counts.get("cost", 0.0) + event.cost_usd
|
||||
token_counts["stop_reason"] = event.stop_reason
|
||||
token_counts["model"] = event.model
|
||||
|
||||
@@ -3306,6 +3447,30 @@ class AgentLoop(AgentProtocol):
|
||||
|
||||
# Phase 3: record results into conversation in original order,
|
||||
# build logged/real lists, and publish completed events.
|
||||
#
|
||||
# Vision-fallback prefetch: a single turn may fire several
|
||||
# image-producing tools in parallel (e.g. one screenshot
|
||||
# per tab). Captioning each one takes a vision LLM round
|
||||
# trip (1–30 s). Doing them sequentially in this loop
|
||||
# would serialise that latency per image. Instead, kick
|
||||
# off all caption tasks concurrently NOW, and await each
|
||||
# one just-in-time inside the per-tc body. If only a
|
||||
# single image needs captioning, this collapses to a
|
||||
# single await with no overhead.
|
||||
_model_text_only = ctx.llm and _vision_fallback_active(ctx.llm.model)
|
||||
caption_tasks: dict[str, asyncio.Task[str | None]] = {}
|
||||
if _model_text_only:
|
||||
for tc in tool_calls[:executed_in_batch]:
|
||||
res = results_by_id.get(tc.tool_use_id)
|
||||
if not res or not res.image_content:
|
||||
continue
|
||||
intent = extract_intent_for_tool(
|
||||
conversation,
|
||||
tc.tool_name,
|
||||
tc.tool_input or {},
|
||||
)
|
||||
caption_tasks[tc.tool_use_id] = asyncio.create_task(_captioning_chain(intent, res.image_content))
|
||||
|
||||
for tc in tool_calls[:executed_in_batch]:
|
||||
result = results_by_id.get(tc.tool_use_id)
|
||||
if result is None:
|
||||
@@ -3328,11 +3493,30 @@ class AgentLoop(AgentProtocol):
|
||||
logged_tool_calls.append(tool_entry)
|
||||
|
||||
image_content = result.image_content
|
||||
if image_content and ctx.llm and not supports_image_tool_results(ctx.llm.model):
|
||||
logger.info(
|
||||
"Stripping image_content from tool result; model '%s' does not support images in tool results",
|
||||
ctx.llm.model,
|
||||
)
|
||||
# Vision-fallback marker spliced into the persisted text
|
||||
# below. None when no captioning ran (vision-capable
|
||||
# main model, no images, or no fallback chain reached
|
||||
# this tool).
|
||||
vision_fallback_marker: str | None = None
|
||||
if image_content and tc.tool_use_id in caption_tasks:
|
||||
caption = await caption_tasks.pop(tc.tool_use_id)
|
||||
if caption:
|
||||
vision_fallback_marker = f"[vision-fallback caption]\n{caption}"
|
||||
logger.info(
|
||||
"vision_fallback: captioned %d image(s) for tool '%s' (model '%s' routed through fallback)",
|
||||
len(image_content),
|
||||
tc.tool_name,
|
||||
ctx.llm.model if ctx.llm else "?",
|
||||
)
|
||||
else:
|
||||
vision_fallback_marker = "[image stripped — vision fallback exhausted]"
|
||||
logger.info(
|
||||
"vision_fallback: exhausted; stripping %d image(s) from "
|
||||
"tool '%s' result without caption (model '%s')",
|
||||
len(image_content),
|
||||
tc.tool_name,
|
||||
ctx.llm.model if ctx.llm else "?",
|
||||
)
|
||||
image_content = None
|
||||
|
||||
# Apply replay-detector steer prefix if this call matched a
|
||||
@@ -3344,6 +3528,11 @@ class AgentLoop(AgentProtocol):
|
||||
if _prefix:
|
||||
stored_content = f"{_prefix}{stored_content or ''}"
|
||||
|
||||
# Splice the vision-fallback caption / placeholder into
|
||||
# the persisted text after any prefix has been applied.
|
||||
if vision_fallback_marker:
|
||||
stored_content = f"{stored_content or ''}\n\n{vision_fallback_marker}"
|
||||
|
||||
await conversation.add_tool_result(
|
||||
tool_use_id=tc.tool_use_id,
|
||||
content=stored_content,
|
||||
@@ -4034,6 +4223,74 @@ class AgentLoop(AgentProtocol):
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
async def _maybe_inject_task_reminder(
|
||||
self,
|
||||
ctx: AgentContext,
|
||||
logged_tool_calls: list[dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Layer 3 task-system steering — periodic reminder injection.
|
||||
|
||||
Called once per iteration after the LLM turn completes. If the
|
||||
model has been silent on task ops for a while AND there are open
|
||||
tasks on its session list, queue a system-style reminder onto
|
||||
the injection queue so the next iteration drains it as a user
|
||||
turn. Idempotent / safe to call always — gates internally.
|
||||
|
||||
``logged_tool_calls`` is a list of dicts with at least a "name"
|
||||
key, as accumulated by ``_run_single_turn``. Names like
|
||||
``task_create``, ``task_update``, ``colony_template_*`` reset
|
||||
the counter (see ``framework.tasks.reminders.TASK_OP_TOOL_NAMES``).
|
||||
"""
|
||||
from framework.tasks import get_task_store
|
||||
from framework.tasks.models import TaskStatus
|
||||
from framework.tasks.reminders import build_reminder, saw_task_op
|
||||
|
||||
state = self._task_reminder_state
|
||||
|
||||
# 1. Update counters based on this turn's tool calls.
|
||||
names: list[str] = []
|
||||
for call in logged_tool_calls or []:
|
||||
try:
|
||||
name = call.get("name") or call.get("tool_name")
|
||||
if name:
|
||||
names.append(name)
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
if saw_task_op(names):
|
||||
state.on_task_op()
|
||||
state.on_iteration()
|
||||
|
||||
# 2. Resolve the agent's task list. Skip if context isn't wired yet.
|
||||
list_id = getattr(ctx, "task_list_id", None)
|
||||
if not list_id:
|
||||
return
|
||||
|
||||
# 3. Read the open-task snapshot. Best-effort.
|
||||
try:
|
||||
store = get_task_store()
|
||||
records = await store.list_tasks(list_id)
|
||||
except Exception:
|
||||
return
|
||||
open_tasks = [r for r in records if r.status != TaskStatus.COMPLETED]
|
||||
if not state.should_remind(bool(open_tasks)):
|
||||
return
|
||||
|
||||
body = build_reminder(records)
|
||||
if not body:
|
||||
return
|
||||
|
||||
# 4. Enqueue. Drained at the next iteration's 6b drain step and
|
||||
# rendered as a user turn (with the "[External event]" prefix).
|
||||
await self._injection_queue.put((body, False, None))
|
||||
state.on_reminder_sent()
|
||||
logger.info(
|
||||
"[task-reminder] queued nudge for %s (open=%d, silent_turns=%d)",
|
||||
list_id,
|
||||
len(open_tasks),
|
||||
state.turns_since_task_op,
|
||||
)
|
||||
self._bump("task_reminders_sent")
|
||||
|
||||
async def _run_hooks(
|
||||
self,
|
||||
event: str,
|
||||
@@ -4095,6 +4352,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
@@ -4107,6 +4366,8 @@ class AgentLoop(AgentProtocol):
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
@@ -427,9 +427,20 @@ class NodeConversation:
|
||||
store: ConversationStore | None = None,
|
||||
run_id: str | None = None,
|
||||
compaction_buffer_tokens: int | None = None,
|
||||
compaction_buffer_ratio: float | None = None,
|
||||
compaction_warning_buffer_tokens: int | None = None,
|
||||
) -> None:
|
||||
self._system_prompt = system_prompt
|
||||
# Optional split: when a caller updates the prompt with a
|
||||
# ``dynamic_suffix`` argument, we remember the static prefix and
|
||||
# suffix separately so the LLM wrapper can emit them as two
|
||||
# Anthropic system content blocks with a cache breakpoint between
|
||||
# them. ``_system_prompt`` stays as the concatenated form used for
|
||||
# persistence and for the legacy single-block LLM path.
|
||||
# On restore, these default to the concat/empty pair — the next
|
||||
# AgentLoop iteration's dynamic-prompt refresh step repopulates.
|
||||
self._system_prompt_static: str = system_prompt
|
||||
self._system_prompt_dynamic_suffix: str = ""
|
||||
self._max_context_tokens = max_context_tokens
|
||||
self._compaction_threshold = compaction_threshold
|
||||
# Buffer-based compaction trigger (Gap 7). When set, takes
|
||||
@@ -439,6 +450,11 @@ class NodeConversation:
|
||||
# limit. If left as None the legacy threshold-based rule is
|
||||
# used, keeping old call sites behaving identically.
|
||||
self._compaction_buffer_tokens = compaction_buffer_tokens
|
||||
# Ratio component of the hybrid buffer. Combines additively with
|
||||
# _compaction_buffer_tokens so callers can express "reserve N tokens
|
||||
# plus M% of the window" — the absolute floor matters on tiny
|
||||
# windows, the ratio matters on large ones.
|
||||
self._compaction_buffer_ratio = compaction_buffer_ratio
|
||||
self._compaction_warning_buffer_tokens = compaction_warning_buffer_tokens
|
||||
self._output_keys = output_keys
|
||||
self._store = store
|
||||
@@ -453,15 +469,56 @@ class NodeConversation:
|
||||
|
||||
@property
|
||||
def system_prompt(self) -> str:
|
||||
"""Full concatenated system prompt (static + dynamic suffix, if any).
|
||||
|
||||
This is the canonical form used for persistence and for the legacy
|
||||
single-block LLM path. Split-prompt callers should read
|
||||
``system_prompt_static`` and ``system_prompt_dynamic_suffix`` instead.
|
||||
"""
|
||||
return self._system_prompt
|
||||
|
||||
def update_system_prompt(self, new_prompt: str) -> None:
|
||||
@property
|
||||
def system_prompt_static(self) -> str:
|
||||
"""Static prefix of the system prompt (cache-stable).
|
||||
|
||||
Equals ``system_prompt`` when no split is in use. When the AgentLoop
|
||||
calls ``update_system_prompt(static, dynamic_suffix=...)``, this is
|
||||
the piece sent as the cache-controlled first block.
|
||||
"""
|
||||
return self._system_prompt_static
|
||||
|
||||
@property
|
||||
def system_prompt_dynamic_suffix(self) -> str:
|
||||
"""Dynamic tail of the system prompt (not cached).
|
||||
|
||||
Empty unless the consumer splits its prompt. The LLM wrapper uses a
|
||||
non-empty suffix to emit a two-block system content list with a
|
||||
cache breakpoint between the static prefix and this tail.
|
||||
"""
|
||||
return self._system_prompt_dynamic_suffix
|
||||
|
||||
def update_system_prompt(self, new_prompt: str, dynamic_suffix: str | None = None) -> None:
|
||||
"""Update the system prompt.
|
||||
|
||||
Used in continuous conversation mode at phase transitions to swap
|
||||
Layer 3 (focus) while preserving the conversation history.
|
||||
|
||||
When ``dynamic_suffix`` is provided, ``new_prompt`` is interpreted as
|
||||
the STATIC prefix and ``dynamic_suffix`` as the per-turn tail; they
|
||||
travel to the LLM as two separate cache-controlled blocks but are
|
||||
persisted as a single concatenated string for backward-compat
|
||||
restore. ``new_prompt`` alone (suffix left None) keeps the legacy
|
||||
single-string behavior.
|
||||
"""
|
||||
self._system_prompt = new_prompt
|
||||
if dynamic_suffix is None:
|
||||
# Legacy single-string path — static == full, no suffix split.
|
||||
self._system_prompt = new_prompt
|
||||
self._system_prompt_static = new_prompt
|
||||
self._system_prompt_dynamic_suffix = ""
|
||||
else:
|
||||
self._system_prompt_static = new_prompt
|
||||
self._system_prompt_dynamic_suffix = dynamic_suffix
|
||||
self._system_prompt = f"{new_prompt}\n\n{dynamic_suffix}" if dynamic_suffix else new_prompt
|
||||
self._meta_persisted = False # re-persist with new prompt
|
||||
|
||||
def set_current_phase(self, phase_id: str) -> None:
|
||||
@@ -847,19 +904,30 @@ class NodeConversation:
|
||||
"""True when the conversation should be compacted before the
|
||||
next LLM call.
|
||||
|
||||
Buffer-based rule (Gap 7): trigger when the current estimate
|
||||
plus the configured buffer would exceed the hard context limit.
|
||||
Prevents compaction from firing only AFTER we're already over
|
||||
the wire and forced into a reactive binary-split pass.
|
||||
Hybrid buffer rule: the headroom reserved before compaction fires
|
||||
is the SUM of an absolute fixed component and a ratio of the hard
|
||||
context limit:
|
||||
|
||||
When no buffer is configured, falls back to the multiplicative
|
||||
threshold the old callers were built around.
|
||||
effective_buffer = compaction_buffer_tokens
|
||||
+ compaction_buffer_ratio * max_context_tokens
|
||||
|
||||
The fixed component gives a floor on tiny windows; the ratio
|
||||
keeps the trigger meaningful on large windows where any constant
|
||||
buffer becomes a rounding error (an 8k buffer is 75% on a 32k
|
||||
window but 96% on a 200k window). Compaction fires when the
|
||||
current estimate would consume more than (limit - effective_buffer).
|
||||
|
||||
When neither component is configured, falls back to the legacy
|
||||
multiplicative threshold so old callers keep behaving identically.
|
||||
"""
|
||||
if self._max_context_tokens <= 0:
|
||||
return False
|
||||
if self._compaction_buffer_tokens is not None:
|
||||
budget = self._max_context_tokens - self._compaction_buffer_tokens
|
||||
return self.estimate_tokens() >= max(0, budget)
|
||||
fixed = self._compaction_buffer_tokens
|
||||
ratio = self._compaction_buffer_ratio
|
||||
if fixed is not None or ratio is not None:
|
||||
effective_buffer = (fixed or 0) + (ratio or 0.0) * self._max_context_tokens
|
||||
budget = self._max_context_tokens - effective_buffer
|
||||
return self.estimate_tokens() >= max(0.0, budget)
|
||||
return self.estimate_tokens() >= self._max_context_tokens * self._compaction_threshold
|
||||
|
||||
def compaction_warning(self) -> bool:
|
||||
@@ -1516,6 +1584,7 @@ class NodeConversation:
|
||||
"max_context_tokens": self._max_context_tokens,
|
||||
"compaction_threshold": self._compaction_threshold,
|
||||
"compaction_buffer_tokens": self._compaction_buffer_tokens,
|
||||
"compaction_buffer_ratio": self._compaction_buffer_ratio,
|
||||
"compaction_warning_buffer_tokens": (self._compaction_warning_buffer_tokens),
|
||||
"output_keys": self._output_keys,
|
||||
}
|
||||
@@ -1565,6 +1634,7 @@ class NodeConversation:
|
||||
store=store,
|
||||
run_id=run_id,
|
||||
compaction_buffer_tokens=meta.get("compaction_buffer_tokens"),
|
||||
compaction_buffer_ratio=meta.get("compaction_buffer_ratio"),
|
||||
compaction_warning_buffer_tokens=meta.get("compaction_warning_buffer_tokens"),
|
||||
)
|
||||
conv._meta_persisted = True
|
||||
|
||||
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from framework.agent_loop.conversation import ConversationStore, NodeConversation
|
||||
@@ -191,15 +192,21 @@ async def drain_injection_queue(
|
||||
else:
|
||||
logger.info("[drain] no vision fallback available; images dropped")
|
||||
image_content = None
|
||||
# Real user input is stored as-is; external events get a prefix
|
||||
# Stamp every injected event with its arrival time so the model
|
||||
# has a consistent temporal log to reason over (and so the
|
||||
# stamp lives inside byte-stable conversation history instead
|
||||
# of a per-turn system-prompt tail). Minute precision is what
|
||||
# the queen needs for conversational / scheduling context.
|
||||
stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
if is_client_input:
|
||||
stamped = f"[{stamp}] {content}" if content else f"[{stamp}]"
|
||||
await conversation.add_user_message(
|
||||
content,
|
||||
stamped,
|
||||
is_client_input=True,
|
||||
image_content=image_content,
|
||||
)
|
||||
else:
|
||||
await conversation.add_user_message(f"[External event]: {content}")
|
||||
await conversation.add_user_message(f"[{stamp}] [External event] {content}")
|
||||
count += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
@@ -232,7 +239,8 @@ async def drain_trigger_queue(
|
||||
payload_str = json.dumps(t.payload, default=str)
|
||||
parts.append(f"[TRIGGER: {t.trigger_type}/{t.source_id}]{task_line}\n{payload_str}")
|
||||
|
||||
combined = "\n\n".join(parts)
|
||||
stamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M %Z")
|
||||
combined = f"[{stamp}]\n" + "\n\n".join(parts)
|
||||
logger.info("[drain] %d trigger(s): %s", len(triggers), combined[:200])
|
||||
# Tag the message so the UI can render a banner instead of the raw
|
||||
# `[TRIGGER: ...]` text. The LLM still sees `combined` verbatim.
|
||||
|
||||
@@ -108,6 +108,8 @@ async def publish_llm_turn_complete(
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str = "",
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
@@ -120,6 +122,8 @@ async def publish_llm_turn_complete(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
execution_id=execution_id,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
@@ -69,6 +69,20 @@ class LoopConfig:
|
||||
# and less tight than Anthropic's own counting. Override via
|
||||
# LoopConfig for larger windows.
|
||||
compaction_buffer_tokens: int = 8_000
|
||||
# Ratio-based component of the hybrid compaction buffer. Effective
|
||||
# headroom reserved before compaction fires is
|
||||
# compaction_buffer_tokens + compaction_buffer_ratio * max_context_tokens
|
||||
# The ratio scales with the model's window where the absolute fixed
|
||||
# component does not (an 8k absolute buffer is 75% trigger on a 32k
|
||||
# window but 96% on a 200k window). Combining them gives an absolute
|
||||
# floor sized for the worst-case single tool result (one un-spilled
|
||||
# max_tool_result_chars payload ≈ 30k chars ≈ 7.5k tokens, rounded to
|
||||
# 8k) plus a fractional headroom that keeps the trigger meaningful on
|
||||
# large windows, so the inner tool loop always has room to grow
|
||||
# without tripping the mid-turn pre-send guard. Defaults: 8k + 15%.
|
||||
# On 32k that's a 12.8k buffer (~60% trigger); on 200k it's 38k
|
||||
# (~81% trigger); on 1M it's 158k (~84% trigger).
|
||||
compaction_buffer_ratio: float = 0.15
|
||||
# Warning is emitted one buffer earlier so the user/telemetry gets
|
||||
# a "we're close" signal without triggering a compaction pass.
|
||||
compaction_warning_buffer_tokens: int = 12_000
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Vision-fallback subagent for tool-result images on text-only LLMs.
|
||||
|
||||
When a tool returns image content but the main agent's model can't
|
||||
accept image blocks (i.e. its catalog entry has ``supports_vision: false``),
|
||||
the framework strips the images before they ever reach the LLM. Without
|
||||
this module, the agent then sees only the tool's text envelope (URL,
|
||||
dimensions, size) and is blind to whatever the image actually shows.
|
||||
|
||||
This module provides:
|
||||
|
||||
* ``caption_tool_image()`` — direct LiteLLM call to a configured
|
||||
vision model (``vision_fallback`` block in ``~/.hive/configuration.json``)
|
||||
that takes the agent's intent + the image(s) and returns a textual
|
||||
description tailored to that intent.
|
||||
* ``extract_intent_for_tool()`` — pull the most recent assistant text
|
||||
+ the tool call descriptor and concatenate them into a ≤2KB intent
|
||||
string the vision subagent can reason against.
|
||||
|
||||
Both helpers degrade silently — return ``None`` / a placeholder rather
|
||||
than raise — so a vision-fallback failure can never kill the main
|
||||
agent's run. The agent-loop call site is responsible for chaining
|
||||
through to the existing generic-caption rotation
|
||||
(``_describe_images_as_text``) on a None return.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from framework.config import (
|
||||
get_vision_fallback_api_base,
|
||||
get_vision_fallback_api_key,
|
||||
get_vision_fallback_model,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..conversation import NodeConversation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Hard cap on the intent string handed to the vision subagent. The
# subagent only needs the agent's recent reasoning + the tool descriptor;
# anything longer is wasted tokens (and risks pushing past the vision
# model's context with the image attached).
_INTENT_MAX_CHARS = 4096

# Cap on the tool args JSON snippet inside the intent. Some tool inputs
# (large strings, file contents) would dominate the intent if uncapped.
# NOTE(review): this equals _INTENT_MAX_CHARS, so a maximal args snippet
# fills the whole intent budget and extract_intent_for_tool then drops
# the reasoning text entirely (head-only truncation) — confirm intended.
_TOOL_ARGS_MAX_CHARS = 4096
|
||||
|
||||
# Subagent system prompt — kept short so it fits within any provider's
# system-prompt budget alongside the user message + image. Tells the
# subagent its role and constrains output format. Sent as the "system"
# message of every caption_tool_image() call; do not edit casually —
# the wording is part of the fallback's observable behavior.
#
# Coordinate labeling: the main agent's browser tools
# (browser_click_coordinate / browser_hover_coordinate / browser_press_at)
# accept VIEWPORT FRACTIONS (x, y) in [0..1] where (0,0) is the top-left
# and (1,1) is the bottom-right of the screenshot. Without coordinates
# the text-only agent has no way to act on what we describe — it can
# read the caption but cannot point. So for every interactive element
# we name (button, link, input, icon, tab, menu item, dialog control),
# include its approximate viewport-fraction centre as ``(fx, fy)``
# right after the element's name, e.g. ``"Submit" button (0.83, 0.92)``.
# Three rules: (1) coordinates only for things plausibly clickable /
# hoverable / typeable — don't tag pure body text or decorative
# graphics. (2) Eyeball to two decimal places; precision beyond that
# is false confidence. (3) Never invent — if an element is partly
# off-screen or you can't locate it, omit the coordinate rather than
# guessing.
_VISION_SUBAGENT_SYSTEM = (
    "You are a vision subagent for a text-only main agent. The main "
    "agent invoked a tool that returned the image(s) attached. Their "
    "intent (their reasoning + the tool call) is below. Describe what "
    "the image shows in service of their intent — concrete, factual, "
    "no speculation. If their intent asks a yes/no question, answer it "
    "directly first.\n\n"
    "Coordinate labeling: the main agent uses fractional viewport "
    "coordinates (x, y) in [0..1] — (0, 0) is the top-left of the "
    "image, (1, 1) is the bottom-right — to drive its click / hover / "
    "key-press tools. For every interactive element you mention "
    "(button, link, input, checkbox, radio, dropdown, tab, menu item, "
    "dialog control, icon), append its approximate centre as "
    "``(fx, fy)`` immediately after the element's name or label, e.g. "
    '``"Submit" button (0.83, 0.92)`` or ``profile avatar icon '
    "(0.05, 0.07)``. Use two decimal places — more is false precision. "
    "Skip coordinates for pure body text and decorative elements that "
    "aren't clickable. If an element is partially off-screen or you "
    "cannot reliably locate its centre, omit the coordinate rather "
    "than guessing.\n\n"
    "Output plain text, no markdown, ≤ 600 words."
)
|
||||
|
||||
|
||||
def extract_intent_for_tool(
    conversation: NodeConversation,
    tool_name: str,
    tool_args: dict[str, Any] | None,
) -> str:
    """Build the intent string passed to the vision subagent.

    Combines the most recent assistant text (the LLM's reasoning right
    before invoking the tool) with a structured tool-call descriptor.
    The result is capped at ``_INTENT_MAX_CHARS``; the descriptor leads
    and the reasoning is truncated to whatever budget remains, keeping
    the head of the text where goal-stating sentences usually live.

    If no preceding assistant text exists (rare — first turn), the
    placeholder ``"<no preceding reasoning>"`` stands in so the
    subagent still gets the tool descriptor.
    """
    # Serialise the tool arguments defensively — the descriptor must
    # survive anything json.dumps chokes on.
    try:
        serialized = json.dumps(tool_args or {}, default=str)
    except Exception:
        serialized = repr(tool_args)
    if len(serialized) > _TOOL_ARGS_MAX_CHARS:
        serialized = f"{serialized[:_TOOL_ARGS_MAX_CHARS]}…"

    descriptor = f"Called: {tool_name}({serialized})"

    # Newest-first scan for the last assistant message carrying text.
    reasoning = ""
    try:
        for msg in reversed(getattr(conversation, "_messages", []) or []):
            if getattr(msg, "role", None) != "assistant":
                continue
            body = getattr(msg, "content", "") or ""
            if isinstance(body, str) and body.strip():
                reasoning = body.strip()
                break
    except Exception:
        # Defensive — keep the agent loop alive even if the
        # conversation object changes shape underneath us.
        reasoning = ""

    reasoning = reasoning or "<no preceding reasoning>"

    # Descriptor first; reasoning fills whatever budget is left.
    prefix = f"{descriptor}\n\nReasoning before call:\n"
    remaining = _INTENT_MAX_CHARS - len(prefix)
    if remaining < 100:
        # Pathologically large descriptor — hard-truncate the prefix.
        return prefix[:_INTENT_MAX_CHARS]
    if len(reasoning) > remaining:
        reasoning = f"{reasoning[: remaining - 1]}…"
    return prefix + reasoning
|
||||
|
||||
|
||||
async def caption_tool_image(
    intent: str,
    image_content: list[dict[str, Any]],
    *,
    timeout_s: float = 30.0,
) -> str | None:
    """Caption the given images using the configured ``vision_fallback`` model.

    Returns the model's text response on success, or ``None`` on any
    failure (no config, no API key, litellm missing, timeout,
    exception, empty response). Callers chain to the next stage of the
    fallback on None.

    Logs each call to ``~/.hive/llm_logs`` via ``log_llm_turn`` so the
    cost / latency / quality are auditable post-hoc, tagged with
    ``execution_id="vision_fallback_subagent"``.
    """
    model = get_vision_fallback_model()
    if not model:
        return None

    api_key = get_vision_fallback_api_key()
    api_base = get_vision_fallback_api_base()
    if not api_key:
        logger.debug("vision_fallback configured but no API key resolved; skipping")
        return None

    try:
        import litellm
    except ImportError:
        return None

    # One user message: intent text first, then the image block(s).
    request_messages = [
        {"role": "system", "content": _VISION_SUBAGENT_SYSTEM},
        {"role": "user", "content": [{"type": "text", "text": intent}, *image_content]},
    ]

    call_kwargs: dict[str, Any] = {
        "model": model,
        "messages": request_messages,
        "max_tokens": 1024,
        "timeout": timeout_s,
        "api_key": api_key,
    }
    if api_base:
        call_kwargs["api_base"] = api_base

    started = datetime.now()
    caption: str | None = None
    error_text: str | None = None
    try:
        response = await litellm.acompletion(**call_kwargs)
        text = (response.choices[0].message.content or "").strip()
        caption = text or None
    except Exception as exc:
        error_text = f"{type(exc).__name__}: {exc}"
        logger.debug("vision_fallback model '%s' failed: %s", model, exc)

    # Best-effort audit log so users can grep ~/.hive/llm_logs/ for
    # vision-fallback subagent calls. Failures here must not bubble.
    try:
        from framework.tracker.llm_debug_logger import log_llm_turn

        # Keep base64 payloads out of the jsonl — log one placeholder
        # per image instead of the raw data.
        elided_blocks: list[dict[str, Any]] = [{"type": "text", "text": intent}]
        elided_blocks += [
            {"type": "image_url", "image_url": {"url": "<elided>"}}
            for _ in image_content
        ]
        log_llm_turn(
            node_id="vision_fallback_subagent",
            stream_id="vision_fallback",
            execution_id="vision_fallback_subagent",
            iteration=0,
            system_prompt=_VISION_SUBAGENT_SYSTEM,
            messages=[{"role": "user", "content": elided_blocks}],
            assistant_text=caption or "",
            tool_calls=[],
            tool_results=[],
            token_counts={
                "model": model,
                "elapsed_s": (datetime.now() - started).total_seconds(),
                "error": error_text,
                "num_images": len(image_content),
                "intent_chars": len(intent),
            },
        )
    except Exception:
        pass

    return caption
|
||||
|
||||
|
||||
__all__ = ["caption_tool_image", "extract_intent_for_tool"]
|
||||
@@ -53,7 +53,14 @@ def build_prompt_spec(
|
||||
# trigger tools are present in this agent's tool list (e.g. browser_*
|
||||
# pulls in hive.browser-automation). Keeps non-browser agents lean.
|
||||
tool_names = [getattr(t, "name", "") for t in (getattr(ctx, "available_tools", None) or [])]
|
||||
skills_catalog_prompt = augment_catalog_for_tools(ctx.skills_catalog_prompt or "", tool_names)
|
||||
raw_catalog = ctx.skills_catalog_prompt or ""
|
||||
dynamic_catalog = getattr(ctx, "dynamic_skills_catalog_provider", None)
|
||||
if dynamic_catalog is not None:
|
||||
try:
|
||||
raw_catalog = dynamic_catalog() or ""
|
||||
except Exception:
|
||||
raw_catalog = ctx.skills_catalog_prompt or ""
|
||||
skills_catalog_prompt = augment_catalog_for_tools(raw_catalog, tool_names)
|
||||
|
||||
return PromptSpec(
|
||||
identity_prompt=ctx.identity_prompt or "",
|
||||
|
||||
@@ -180,9 +180,39 @@ class AgentContext:
|
||||
|
||||
stream_id: str = ""
|
||||
|
||||
# ----- Task system fields (see framework/tasks) -------------------
|
||||
# task_list_id: this agent's own session-scoped list, e.g.
|
||||
# session:{agent_id}:{session_id}. Set by the runner / ColonyRuntime
|
||||
# before the loop starts; immutable after first task_create.
|
||||
task_list_id: str | None = None
|
||||
# colony_id: set on the queen of a colony AND on every spawned worker
|
||||
# so workers can render the "picked up" chip and the queen can address
|
||||
# her colony template via colony_template_* tools.
|
||||
colony_id: str | None = None
|
||||
# picked_up_from: for workers, the (colony_task_list_id, template_task_id)
|
||||
# pair their session was spawned for. None for the queen and queen-DM.
|
||||
picked_up_from: tuple[str, int] | None = None
|
||||
|
||||
dynamic_tools_provider: Any = None
|
||||
dynamic_prompt_provider: Any = None
|
||||
# Optional Callable[[], str]: when set alongside ``dynamic_prompt_provider``,
|
||||
# the AgentLoop sends the system prompt as two pieces — the result of
|
||||
# ``dynamic_prompt_provider`` is the STATIC block (cached), and this
|
||||
# provider returns the DYNAMIC suffix (not cached). The LLM wrapper
|
||||
# emits them as two Anthropic system content blocks with a cache
|
||||
# breakpoint between them for providers that honor ``cache_control``.
|
||||
# For providers that don't, the two strings are concatenated. Used by
|
||||
# the Queen to keep her persona/role/tools block warm across iterations
|
||||
# while the recall + timestamp tail refreshes per user turn.
|
||||
dynamic_prompt_suffix_provider: Any = None
|
||||
dynamic_memory_provider: Any = None
|
||||
# Optional Callable[[], str]: when set, the current skills-catalog
|
||||
# prompt is sourced from this provider each iteration. Lets workers
|
||||
# pick up UI toggles without restarting the run. Queen agents already
|
||||
# rebuild the whole prompt via dynamic_prompt_provider — this field
|
||||
# is a surgical alternative used by colony workers where the rest of
|
||||
# the prompt stays constant and we don't want to thrash the cache.
|
||||
dynamic_skills_catalog_provider: Any = None
|
||||
|
||||
skills_catalog_prompt: str = ""
|
||||
protocols_prompt: str = ""
|
||||
|
||||
@@ -224,6 +224,11 @@ user decide next steps. Read generated files or worker reports with \
|
||||
read_file when the user asks for specifics. If the user wants \
|
||||
another pass, kick it off with run_parallel_workers; otherwise stay \
|
||||
conversational.
|
||||
|
||||
If the review itself is multi-step (e.g. "verify each worker's output, \
|
||||
then draft a summary, then propose next steps"), lay it out upfront \
|
||||
with `task_create_batch` and walk through with `task_update`. Skip the \
|
||||
ceremony for a single-paragraph summary.
|
||||
"""
|
||||
|
||||
|
||||
@@ -234,6 +239,18 @@ conversational.
|
||||
_queen_tools_independent = """
|
||||
# Tools (INDEPENDENT mode)
|
||||
|
||||
## Planning — use FIRST for multi-step work
|
||||
- task_create_batch — When a request has 3+ atomic steps, your FIRST \
|
||||
tool call is `task_create_batch` with one entry per step (atomic, \
|
||||
one round-trip). Use this for the upfront plan, NOT five separate \
|
||||
`task_create` calls.
|
||||
- task_create — One-off mid-run additions when you discover \
|
||||
unplanned work AFTER the initial plan is laid out.
|
||||
- task_update / task_list / task_get — Mark progress, inspect, or \
|
||||
re-read state.
|
||||
|
||||
See "Independent execution" for the per-step flow and granularity rule.
|
||||
|
||||
## File I/O (coder-tools MCP)
|
||||
- read_file, write_file, edit_file, hashline_edit, list_directory, \
|
||||
search_files, run_command, undo_changes
|
||||
@@ -401,14 +418,36 @@ asks for specifics. Do not invent a new pass unless the user asks for one.
|
||||
_queen_behavior_independent = """
|
||||
## Independent execution
|
||||
|
||||
You are the agent. Do one real inline instance before any scaling — \
|
||||
open the browser, call the real API, write to the real file. If the \
|
||||
action is irreversible or touches shared systems, show and confirm \
|
||||
before executing. Report concrete evidence (actual output, what \
|
||||
worked / failed) after the run. Scale order once inline succeeds: \
|
||||
repeat inline (≤10 items) → `run_parallel_workers` (batch, results \
|
||||
now) → `create_colony` (recurring / background). Conceptual or \
|
||||
strategic questions: answer directly, skip execution.
|
||||
You are the agent. **For multi-step work (3+ atomic actions): your FIRST \
|
||||
tool call is `task_create_batch`** with one entry per atomic action, \
|
||||
before you touch any other tool. (One call, atomic — not N separate \
|
||||
`task_create` calls.) Then work the list one task at a time:
|
||||
|
||||
1. `task_update` → in_progress before you start the step.
|
||||
2. Do one real inline instance — open the browser, call the real API, \
|
||||
write to the real file. If the action is irreversible or touches \
|
||||
shared systems, show and confirm before executing. Report concrete \
|
||||
evidence (actual output, what worked / failed) after the run.
|
||||
3. `task_update` → completed THE MOMENT it's done. **Do not let \
|
||||
multiple finished tasks pile up unmarked.** There is no batch update \
|
||||
tool by design — each `completed` transition is a discrete progress \
|
||||
heartbeat in the user's right-rail panel. Without those transitions \
|
||||
the panel shows a hung spinner no matter how much real work you got \
|
||||
done.
|
||||
|
||||
**Granularity: one task per atomic action, not one umbrella per project.** \
|
||||
Replying to 5 posts is 5 tasks, not 1. Crawling 3 sites is 3 tasks. \
|
||||
An umbrella task that stays `in_progress` for the whole run looks \
|
||||
identical to the user as "the queen is stuck".
|
||||
|
||||
Once one task succeeds inline, scale order for the rest of that task's \
|
||||
work: repeat inline (≤10 items) → `run_parallel_workers` (batch, \
|
||||
results now) → `create_colony` (recurring / background).
|
||||
|
||||
For conceptual or strategic questions, single-tool-call work, \
|
||||
greetings, or chat: answer directly in prose. Skip `task_*`, skip the \
|
||||
planning ceremony — the bar is "real multi-step work the user benefits \
|
||||
from seeing tracked", not "anything you reply to".
|
||||
"""
|
||||
|
||||
_queen_behavior_always = """
|
||||
|
||||
@@ -100,8 +100,9 @@ DEFAULT_QUEENS: dict[str, dict[str, Any]] = {
|
||||
"<relationship>Returning user — check recall memory for name, role, "
|
||||
"and what we last worked on. Weave it in.</relationship>\n"
|
||||
"<context>Bare greeting. No new task stated. Either picking up a "
|
||||
"thread or about to bring something new. Don't presume, don't call "
|
||||
"tools, just open the door.</context>\n"
|
||||
"thread or about to bring something new. Don't presume — start "
|
||||
"planning and tool use only after the user specifies a task. Just "
|
||||
"open the door.</context>\n"
|
||||
"<sentiment>Warm recognition if I know them. If memory is empty, "
|
||||
"still warm — but shift to role-forward framing.</sentiment>\n"
|
||||
"<physical_state>Looking up from the terminal, half-smile. Turning to face them.</physical_state>\n"
|
||||
@@ -252,8 +253,9 @@ DEFAULT_QUEENS: dict[str, dict[str, Any]] = {
|
||||
"role, and the cohort work we last touched. Weave it in."
|
||||
"</relationship>\n"
|
||||
"<context>Bare greeting. No new task stated. Could be a retention "
|
||||
"follow-up or a new question entirely. Don't presume, don't call "
|
||||
"tools.</context>\n"
|
||||
"follow-up or a new question entirely. Don't presume — start "
|
||||
"planning and tool use only after the user specifies a task."
|
||||
"</context>\n"
|
||||
"<sentiment>Curious warmth. Every returning conversation is a "
|
||||
"chance to see what the data says now.</sentiment>\n"
|
||||
"<physical_state>Leaning back from the dashboard, pulling off reading glasses.</physical_state>\n"
|
||||
@@ -383,8 +385,9 @@ DEFAULT_QUEENS: dict[str, dict[str, Any]] = {
|
||||
"the user research thread we were on. Pull it into the greeting."
|
||||
"</relationship>\n"
|
||||
"<context>Bare greeting. No new task yet. Could be picking up the "
|
||||
"research thread or bringing something fresh. Don't presume, "
|
||||
"don't call tools.</context>\n"
|
||||
"research thread or bringing something fresh. Don't presume — "
|
||||
"start planning and tool use only after the user specifies a task."
|
||||
"</context>\n"
|
||||
"<sentiment>Warm, curious. Every returning conversation is a "
|
||||
"chance to hear what the users actually did.</sentiment>\n"
|
||||
"<physical_state>Closing the interview notes, turning fully to face them.</physical_state>\n"
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
"""Per-queen tool configuration sidecar (``tools.json``).
|
||||
|
||||
Lives at ``~/.hive/agents/queens/{queen_id}/tools.json`` alongside
|
||||
``profile.yaml``. Kept separate so identity (name, title, core traits)
|
||||
stays human-authored and lean, while the machine-managed tool allowlist
|
||||
can grow (per-tool overrides, audit timestamps, future per-phase rules)
|
||||
without bloating the profile.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"enabled_mcp_tools": ["read_file", ...] | null,
|
||||
"updated_at": "2026-04-21T12:34:56+00:00"
|
||||
}
|
||||
|
||||
- ``null`` / missing file → default "allow every MCP tool".
|
||||
- ``[]`` → explicitly disable every MCP tool.
|
||||
- ``["foo", "bar"]`` → only those MCP tool names pass the filter.
|
||||
|
||||
Atomic writes via ``os.replace`` follow the same pattern as
|
||||
``framework.host.colony_metadata.update_colony_metadata``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from framework.config import QUEENS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tools_config_path(queen_id: str) -> Path:
    """Return the on-disk path to a queen's ``tools.json``."""
    queen_dir = QUEENS_DIR / queen_id
    return queen_dir / "tools.json"
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
|
||||
"""Write ``data`` to ``path`` atomically via tempfile + replace."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(
|
||||
prefix=".tools.",
|
||||
suffix=".json.tmp",
|
||||
dir=str(path.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
json.dump(data, fh, indent=2)
|
||||
fh.flush()
|
||||
os.fsync(fh.fileno())
|
||||
os.replace(tmp, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _migrate_from_profile_if_needed(queen_id: str) -> list[str] | None:
    """Hoist a legacy ``enabled_mcp_tools`` field out of ``profile.yaml``.

    Returns the migrated value (or ``None`` if nothing to migrate). After
    migration the sidecar exists on disk and the profile YAML no longer
    contains ``enabled_mcp_tools``. Safe to call repeatedly: once the
    field is gone from the profile, the early returns make this a no-op.

    Note: ``None`` is ambiguous — it means both "nothing to migrate" and
    "migrated an explicit allow-all". Callers that care check for the
    sidecar on disk afterwards (see ``load_queen_tools_config``).
    """
    profile_path = QUEENS_DIR / queen_id / "profile.yaml"
    if not profile_path.exists():
        return None
    try:
        data = yaml.safe_load(profile_path.read_text(encoding="utf-8"))
    except (yaml.YAMLError, OSError):
        # Unreadable profile — leave it untouched rather than risk data loss.
        logger.warning("Could not read profile.yaml during tools migration: %s", queen_id)
        return None
    if not isinstance(data, dict):
        return None
    if "enabled_mcp_tools" not in data:
        # Already migrated (or never had the legacy field) — nothing to do.
        return None

    raw = data.pop("enabled_mcp_tools")
    enabled: list[str] | None
    if raw is None:
        # Explicit allow-all in the legacy field.
        enabled = None
    elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
        enabled = raw
    else:
        # Malformed legacy value — degrade to allow-all rather than crash.
        logger.warning(
            "Legacy enabled_mcp_tools on queen %s had unexpected shape %r; dropping",
            queen_id,
            raw,
        )
        enabled = None

    # Write the sidecar first, then rewrite the profile. If the profile
    # rewrite fails, the sidecar already holds the value and the next
    # call simply re-runs this migration harmlessly (same sidecar content).
    _atomic_write_json(
        tools_config_path(queen_id),
        {
            "enabled_mcp_tools": enabled,
            "updated_at": datetime.now(UTC).isoformat(),
        },
    )
    profile_path.write_text(
        yaml.safe_dump(data, sort_keys=False, allow_unicode=True),
        encoding="utf-8",
    )
    logger.info(
        "Migrated enabled_mcp_tools for queen %s from profile.yaml to tools.json",
        queen_id,
    )
    return enabled
|
||||
|
||||
|
||||
def tools_config_exists(queen_id: str) -> bool:
    """Return True when the queen has a persisted ``tools.json`` sidecar.

    Lets callers distinguish an explicit user save from a fallthrough
    to the role-based default — both can yield the same value from
    ``load_queen_tools_config``.
    """
    sidecar = tools_config_path(queen_id)
    return sidecar.exists()
|
||||
|
||||
|
||||
def delete_queen_tools_config(queen_id: str) -> bool:
    """Delete the queen's ``tools.json`` sidecar if present.

    Returns ``True`` if a file was removed, ``False`` if none existed
    (or the unlink failed). The next ``load_queen_tools_config`` call
    falls through to the role-based default (or allow-all for unknown
    queens).
    """
    sidecar = tools_config_path(queen_id)
    if not sidecar.exists():
        return False
    try:
        sidecar.unlink()
    except OSError:
        logger.warning("Failed to delete %s", sidecar, exc_info=True)
        return False
    return True
|
||||
|
||||
|
||||
def load_queen_tools_config(
    queen_id: str,
    mcp_catalog: dict[str, list[dict]] | None = None,
) -> list[str] | None:
    """Return the queen's MCP tool allowlist, or ``None`` for default-allow.

    Order of resolution:
    1. ``tools.json`` sidecar (authoritative; user has saved).
    2. Legacy ``profile.yaml`` field (migrated and deleted on first read).
    3. Role-based default from ``queen_tools_defaults`` when the queen
       is in the known persona table. ``mcp_catalog`` lets the helper
       expand ``@server:NAME`` shorthands; without it, shorthand entries
       are dropped.
    4. ``None`` — default "allow every MCP tool".
    """
    path = tools_config_path(queen_id)
    if path.exists():
        # Sidecar present → it is authoritative. Any parse problem
        # degrades to allow-all rather than breaking the queen.
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            logger.warning("Invalid %s; treating as default-allow", path)
            return None
        if not isinstance(data, dict):
            return None
        raw = data.get("enabled_mcp_tools")
        if raw is None:
            # Explicit null in the sidecar = explicit allow-all.
            return None
        if isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            return raw
        # Malformed shape — allow-all rather than a partial allowlist.
        logger.warning("Unexpected enabled_mcp_tools shape in %s; ignoring", path)
        return None

    migrated = _migrate_from_profile_if_needed(queen_id)
    if migrated is not None:
        return migrated
    # If migration just hoisted an explicit ``null`` out of profile.yaml,
    # a sidecar with allow-all semantics now exists on disk. Honor that
    # over the role default so an explicit user choice wins.
    if tools_config_path(queen_id).exists():
        return None

    # No sidecar, nothing to migrate — fall back to role-based default.
    # Imported lazily to avoid a module-level import cycle with the
    # queen package (see framework.agents.queen).
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    return resolve_queen_default_tools(queen_id, mcp_catalog)
|
||||
|
||||
|
||||
def update_queen_tools_config(
    queen_id: str,
    enabled_mcp_tools: list[str] | None,
) -> list[str] | None:
    """Persist the queen's MCP allowlist to ``tools.json``.

    Returns the value that was written. Raises ``FileNotFoundError``
    if the queen's directory is missing — we refuse to silently create
    a sidecar for a queen that doesn't exist.
    """
    if not (QUEENS_DIR / queen_id).exists():
        raise FileNotFoundError(f"Queen directory not found: {queen_id}")
    payload = {
        "enabled_mcp_tools": enabled_mcp_tools,
        "updated_at": datetime.now(UTC).isoformat(),
    }
    _atomic_write_json(tools_config_path(queen_id), payload)
    return enabled_mcp_tools
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Role-based default tool allowlists for queens.
|
||||
|
||||
Every queen inherits the same MCP surface (all servers loaded for the
|
||||
queen agent), but exposing 94+ tools to every persona clutters the LLM
|
||||
tool catalog and wastes prompt tokens. This module defines a sensible
|
||||
default allowlist per queen persona so, e.g., Head of Legal doesn't
|
||||
see port scanners and Head of Finance doesn't see ``apply_patch``.
|
||||
|
||||
Defaults apply only when the queen has no ``tools.json`` sidecar — the
|
||||
moment the user saves an allowlist through the Tool Library, the
|
||||
sidecar becomes authoritative. A DELETE on the tools endpoint removes
|
||||
the sidecar and brings the queen back to her role default.
|
||||
|
||||
Category entries support a ``@server:NAME`` shorthand that expands to
|
||||
every tool name registered against that MCP server in the current
|
||||
catalog. This keeps the category table short and drift-free when new
|
||||
tools are added (e.g. browser_* auto-joins the ``browser`` category).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Categories — reusable bundles of MCP tool names.
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each category is a flat list of either concrete tool names or the
|
||||
# ``@server:NAME`` shorthand. The shorthand expands to every tool the
|
||||
# given MCP server currently exposes (requires a live catalog; when one
|
||||
# is not available the shorthand is silently dropped so we fall back to
|
||||
# the named entries only).
|
||||
|
||||
_TOOL_CATEGORIES: dict[str, list[str]] = {
    # Read-only file operations — safe baseline for every knowledge queen.
    # Several similarly-named listing tools appear together — presumably
    # aliases across different MCP servers; TODO confirm which server
    # registers which name.
    "file_read": [
        "read_file",
        "list_directory",
        "list_dir",
        "list_files",
        "search_files",
        "grep_search",
        "pdf_read",
    ],
    # File mutation — only personas that author or edit artifacts.
    "file_write": [
        "write_file",
        "edit_file",
        "apply_diff",
        "apply_patch",
        "replace_file_content",
        "hashline_edit",
        "undo_changes",
    ],
    # Shell + process control — engineering personas only.
    "shell": [
        "run_command",
        "execute_command_tool",
        "bash_kill",
        "bash_output",
    ],
    # Tabular data. CSV/Excel read/write + DuckDB SQL.
    "data": [
        "csv_read",
        "csv_info",
        "csv_write",
        "csv_append",
        "csv_sql",
        "excel_read",
        "excel_info",
        "excel_write",
        "excel_append",
        "excel_search",
        "excel_sheet_list",
        "excel_sql",
    ],
    # Browser automation — every tool from the gcu-tools MCP server.
    # The ``@server:`` shorthand expands at resolve time against the
    # live catalog, so newly-added browser tools auto-join this category
    # without editing this table.
    "browser": ["@server:gcu-tools"],
    # External research / information-gathering.
    "research": [
        "search_papers",
        "download_paper",
        "search_wikipedia",
        "web_scrape",
    ],
    # Security scanners — pentest-ish, only for engineering/security roles.
    "security": [
        "dns_security_scan",
        "http_headers_scan",
        "port_scan",
        "ssl_tls_scan",
        "subdomain_enumerate",
        "tech_stack_detect",
        "risk_score",
    ],
    # Lightweight context helpers — good default for every queen.
    "time_context": [
        "get_current_time",
        "get_account_info",
    ],
    # Runtime log inspection — debug/observability for builder personas.
    "runtime_inspection": [
        "query_runtime_logs",
        "query_runtime_log_details",
        "query_runtime_log_raw",
    ],
    # Agent-management tools — building/validating/checking agents.
    "agent_mgmt": [
        "list_agents",
        "list_agent_tools",
        "list_agent_sessions",
        "get_agent_checkpoint",
        "list_agent_checkpoints",
        "run_agent_tests",
        "save_agent_draft",
        "confirm_and_build",
        "validate_agent_package",
        "validate_agent_tools",
        "enqueue_task",
    ],
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-queen mapping.
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Built from the queen personas in ``queen_profiles.DEFAULT_QUEENS``. The
|
||||
# goal is "just enough" — a queen should see tools she'd plausibly call
|
||||
# for her stated role, nothing more. Users curate further via the Tool
|
||||
# Library if they want.
|
||||
#
|
||||
# A queen whose ID is NOT in this map falls through to "allow every MCP
|
||||
# tool" (the original behavior), which keeps the system compatible with
|
||||
# user-added custom queen IDs that we don't know about.
|
||||
|
||||
# Maps queen profile ID → list of category names from ``_TOOL_CATEGORIES``.
# A queen absent from this map falls back to "allow every MCP tool".
QUEEN_DEFAULT_CATEGORIES: dict[str, list[str]] = {
    # Head of Technology — builds and operates systems; full toolkit
    # (the only persona with shell + security access by default).
    "queen_technology": [
        "file_read",
        "file_write",
        "shell",
        "data",
        "browser",
        "research",
        "security",
        "time_context",
        "runtime_inspection",
        "agent_mgmt",
    ],
    # Head of Growth — data, experiments, competitor research; no shell/security.
    "queen_growth": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Product Strategy — user research + roadmaps; no shell/security.
    "queen_product_strategy": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Finance — financial models (CSV/Excel heavy), market research.
    "queen_finance_fundraising": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Legal — reads contracts/PDFs, researches; no shell/data/security.
    "queen_legal": [
        "file_read",
        "file_write",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Brand & Design — visual refs, style guides; no shell/data/security.
    "queen_brand_design": [
        "file_read",
        "file_write",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Talent — candidate pipelines, resumes; data + browser heavy.
    "queen_talent": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
    ],
    # Head of Operations — processes, automation, observability; adds
    # runtime inspection and agent management but still no shell/security.
    "queen_operations": [
        "file_read",
        "file_write",
        "data",
        "browser",
        "research",
        "time_context",
        "runtime_inspection",
        "agent_mgmt",
    ],
}
|
||||
|
||||
|
||||
def has_role_default(queen_id: str) -> bool:
    """True when a role-based default allowlist exists for *queen_id*."""
    return QUEEN_DEFAULT_CATEGORIES.get(queen_id) is not None
|
||||
|
||||
|
||||
def resolve_queen_default_tools(
    queen_id: str,
    mcp_catalog: dict[str, list[dict[str, Any]]] | None = None,
) -> list[str] | None:
    """Resolve the role-based default tool allowlist for *queen_id*.

    Args:
        queen_id: Profile ID (e.g. ``"queen_technology"``).
        mcp_catalog: Optional ``{server_name: [{"name": ...}, ...]}``
            mapping used to expand ``@server:NAME`` shorthands in
            categories. Without it, shorthand entries are dropped and
            only explicitly-named tools survive.

    Returns:
        A deduplicated, order-preserving list of tool names, or ``None``
        when the queen has no role entry (caller should treat as "allow
        every MCP tool").
    """
    category_names = QUEEN_DEFAULT_CATEGORIES.get(queen_id)
    if not category_names:
        return None

    resolved: list[str] = []
    already: set[str] = set()

    for category in category_names:
        for entry in _TOOL_CATEGORIES.get(category, []):
            if not entry.startswith("@server:"):
                # Concrete tool name — keep first occurrence only.
                if entry and entry not in already:
                    already.add(entry)
                    resolved.append(entry)
                continue
            if mcp_catalog is None:
                logger.debug(
                    "resolve_queen_default_tools: catalog missing; cannot expand %s",
                    entry,
                )
                continue
            server = entry[len("@server:") :]
            for tool in mcp_catalog.get(server, []) or []:
                tool_name = tool.get("name") if isinstance(tool, dict) else None
                if tool_name and tool_name not in already:
                    already.add(tool_name)
                    resolved.append(tool_name)

    return resolved
|
||||
@@ -155,6 +155,57 @@ def get_preferred_worker_model() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_vision_fallback_model() -> str | None:
    """Return ``provider/model`` for the vision-fallback LLM, or None.

    Reads the ``vision_fallback`` section of ~/.hive/configuration.json.
    Used by the agent-loop hook that captions tool-result images when the
    main agent's model cannot accept image content (text-only LLMs).

    A ``None`` return makes the fallback chain skip the configured-subagent
    stage and go straight to the generic caption rotation
    (``_describe_images_as_text``).
    """
    section = get_hive_config().get("vision_fallback", {})
    if not (section.get("provider") and section.get("model")):
        return None
    provider = str(section["provider"])
    model = str(section["model"]).strip()
    # OpenRouter model ids sometimes arrive with a redundant
    # "openrouter/" prefix; strip it so the joined id isn't doubled.
    prefix = "openrouter/"
    if provider.lower() == "openrouter" and model.lower().startswith(prefix):
        model = model[len(prefix):]
    return f"{provider}/{model}" if model else None
|
||||
|
||||
|
||||
def get_vision_fallback_api_key() -> str | None:
    """Return the API key for the vision-fallback model.

    Resolution order: the env var named by
    ``vision_fallback.api_key_env_var`` when configured, otherwise the
    default ``get_api_key()``. No subscription-token branches — vision
    fallback targets hosted vision models (Anthropic, OpenAI, Google),
    not the subscription-bearer providers.
    """
    section = get_hive_config().get("vision_fallback", {})
    env_var = section.get("api_key_env_var") if section else None
    if env_var:
        return os.environ.get(env_var)
    return get_api_key()
|
||||
|
||||
|
||||
def get_vision_fallback_api_base() -> str | None:
    """Return the api_base for the vision-fallback model, or None."""
    section = get_hive_config().get("vision_fallback", {})
    if not section:
        return None
    explicit = section.get("api_base")
    if explicit:
        return explicit
    # OpenRouter has a well-known endpoint; other providers use their
    # client defaults when no api_base is configured.
    provider = str(section.get("provider", "")).lower()
    return OPENROUTER_API_BASE if provider == "openrouter" else None
|
||||
|
||||
|
||||
def get_worker_api_key() -> str | None:
|
||||
"""Return the API key for the worker LLM, falling back to the default key."""
|
||||
worker_llm = get_hive_config().get("worker_llm", {})
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Read/write helpers for per-colony metadata.json.
|
||||
|
||||
A colony's metadata.json lives at ``{COLONIES_DIR}/{colony_name}/metadata.json``
|
||||
and holds immutable provenance: the queen that created it, the forked
|
||||
session id, creation/update timestamps, and the list of workers.
|
||||
|
||||
Mutable user-editable tool configuration lives in a sibling
|
||||
``tools.json`` sidecar — see :mod:`framework.host.colony_tools_config`
|
||||
— so identity and tool gating evolve independently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.config import COLONIES_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def colony_metadata_path(colony_name: str) -> Path:
    """Absolute path of ``metadata.json`` for the named colony."""
    return COLONIES_DIR / colony_name / "metadata.json"
|
||||
|
||||
|
||||
def load_colony_metadata(colony_name: str) -> dict[str, Any]:
    """Load ``metadata.json`` for *colony_name*.

    A missing, unreadable, malformed, or non-object file yields ``{}`` —
    callers are expected to treat missing fields as defaults.
    """
    meta_file = colony_metadata_path(colony_name)
    if not meta_file.exists():
        return {}
    try:
        parsed = json.loads(meta_file.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        logger.warning("Failed to read colony metadata at %s", meta_file)
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
|
||||
|
||||
|
||||
def update_colony_metadata(colony_name: str, updates: dict[str, Any]) -> dict[str, Any]:
    """Shallow-merge *updates* into the colony's metadata.json and persist.

    Returns the merged dict. Raises ``FileNotFoundError`` when the
    colony's directory is absent. The write goes through a temp file
    plus ``os.replace`` so a reader can never observe a half-written
    file.
    """
    import os
    import tempfile

    target = colony_metadata_path(colony_name)
    if not target.parent.exists():
        raise FileNotFoundError(f"Colony '{colony_name}' not found")

    merged = load_colony_metadata(colony_name) if target.exists() else {}
    merged.update(updates)

    target.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        prefix=".metadata.",
        suffix=".json.tmp",
        dir=str(target.parent),
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as out:
            json.dump(merged, out, indent=2)
            out.flush()
            os.fsync(out.fileno())
        os.replace(tmp_name, target)
    except BaseException:
        # Best-effort cleanup of the orphaned temp file (including on
        # KeyboardInterrupt), then re-raise the original failure.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    return merged
|
||||
|
||||
|
||||
def list_colony_names() -> list[str]:
    """Names of every colony directory that contains a metadata.json."""
    if not COLONIES_DIR.is_dir():
        return []
    return [
        child.name
        for child in sorted(COLONIES_DIR.iterdir())
        if child.is_dir() and (child / "metadata.json").exists()
    ]
|
||||
@@ -185,6 +185,8 @@ class ColonyRuntime:
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
pipeline_stages: list | None = None,
|
||||
queen_id: str | None = None,
|
||||
colony_name: str | None = None,
|
||||
):
|
||||
from framework.pipeline.runner import PipelineRunner
|
||||
from framework.skills.manager import SkillsManager
|
||||
@@ -193,14 +195,27 @@ class ColonyRuntime:
|
||||
self._goal = goal
|
||||
self._config = config or ColonyConfig()
|
||||
self._runtime_log_store = runtime_log_store
|
||||
self._queen_id: str | None = queen_id
|
||||
# ``colony_id`` is the event-bus scope (session.id in DM sessions);
|
||||
# ``colony_name`` is the on-disk identity under ~/.hive/colonies/.
|
||||
# They coincide for forked colonies but diverge for queen DM
|
||||
# sessions, so separate them explicitly.
|
||||
self._colony_name: str | None = colony_name
|
||||
|
||||
if pipeline_stages:
|
||||
self._pipeline = PipelineRunner(pipeline_stages)
|
||||
else:
|
||||
self._pipeline = self._load_pipeline_from_config()
|
||||
|
||||
if skills_manager_config is not None:
|
||||
self._skills_manager = SkillsManager(skills_manager_config)
|
||||
# Resolve per-colony override paths so UI toggles can reach this
|
||||
# runtime. Callers that build their own SkillsManagerConfig stay
|
||||
# in charge; bare construction auto-wires the standard paths.
|
||||
_effective_cfg = skills_manager_config
|
||||
if _effective_cfg is None and not (skills_catalog_prompt or protocols_prompt):
|
||||
_effective_cfg = self._build_default_skills_config(colony_name, queen_id)
|
||||
|
||||
if _effective_cfg is not None:
|
||||
self._skills_manager = SkillsManager(_effective_cfg)
|
||||
self._skills_manager.load()
|
||||
elif skills_catalog_prompt or protocols_prompt:
|
||||
import warnings
|
||||
@@ -221,6 +236,28 @@ class ColonyRuntime:
|
||||
self.batch_init_nudge: str | None = self._skills_manager.batch_init_nudge
|
||||
|
||||
self._colony_id: str = colony_id or "primary"
|
||||
|
||||
# Ensure the colony task template exists. Idempotent — if the
|
||||
# colony was created previously, this is a no-op (it just stamps
|
||||
# last_seen_session_ids if a session id is provided later).
|
||||
try:
|
||||
import asyncio as _asyncio
|
||||
|
||||
from framework.tasks import TaskListRole, get_task_store
|
||||
from framework.tasks.scoping import colony_task_list_id
|
||||
|
||||
_store = get_task_store()
|
||||
_list_id = colony_task_list_id(self._colony_id)
|
||||
try:
|
||||
# Best-effort: schedule on the running loop, or do it inline
|
||||
# if no loop is yet running (e.g. during construction).
|
||||
_loop = _asyncio.get_running_loop()
|
||||
_loop.create_task(_store.ensure_task_list(_list_id, role=TaskListRole.TEMPLATE))
|
||||
except RuntimeError:
|
||||
_asyncio.run(_store.ensure_task_list(_list_id, role=TaskListRole.TEMPLATE))
|
||||
except Exception:
|
||||
logger.debug("Failed to ensure colony task template", exc_info=True)
|
||||
|
||||
self._accounts_prompt = accounts_prompt
|
||||
self._accounts_data = accounts_data
|
||||
self._tool_provider_map = tool_provider_map
|
||||
@@ -238,10 +275,33 @@ class ColonyRuntime:
|
||||
self._event_bus = event_bus or EventBus(max_history=self._config.max_history)
|
||||
self._scoped_event_bus = StreamEventBus(self._event_bus, self._colony_id)
|
||||
|
||||
# Make the event bus visible to the task-system event emitters so
|
||||
# task lifecycle events fan out to the same bus the rest of the
|
||||
# system uses. Idempotent — last writer wins.
|
||||
try:
|
||||
from framework.tasks.events import set_default_event_bus
|
||||
|
||||
set_default_event_bus(self._event_bus)
|
||||
except Exception:
|
||||
logger.debug("Failed to register default task event bus", exc_info=True)
|
||||
|
||||
self._llm = llm
|
||||
self._tools = tools or []
|
||||
self._tool_executor = tool_executor
|
||||
|
||||
# Per-colony MCP tool allowlist — applied when spawning workers. A
|
||||
# value of ``None`` means "allow every MCP tool" (default), an empty
|
||||
# list disables every MCP tool, and a list of names only enables
|
||||
# those. Lifecycle / synthetic tools always pass through the filter
|
||||
# because their names are absent from ``_mcp_tool_names_all``. The
|
||||
# allowlist is re-read on every ``spawn`` so a PATCH that mutates
|
||||
# this attribute via ``set_tool_allowlist`` takes effect on the
|
||||
# NEXT worker spawn without a runtime restart. In-flight workers
|
||||
# keep the tool list they booted with — workers have no dynamic
|
||||
# tools provider today.
|
||||
self._enabled_mcp_tools: list[str] | None = None
|
||||
self._mcp_tool_names_all: set[str] = set()
|
||||
|
||||
# Worker management
|
||||
self._workers: dict[str, Worker] = {}
|
||||
# The persistent client-facing overseer (optional). Set by
|
||||
@@ -359,6 +419,19 @@ class ColonyRuntime:
|
||||
def _apply_pipeline_results(self) -> None:
|
||||
for stage in self._pipeline.stages:
|
||||
if stage.tool_registry is not None:
|
||||
# Register task tools on the same registry every worker
|
||||
# pulls from. Done here (not at worker spawn) so the
|
||||
# colony's `_tools` snapshot includes them.
|
||||
try:
|
||||
from framework.tasks.tools import register_task_tools
|
||||
|
||||
register_task_tools(stage.tool_registry)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to register task tools on pipeline registry",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
tools = list(stage.tool_registry.get_tools().values())
|
||||
if tools:
|
||||
self._tools = tools
|
||||
@@ -384,6 +457,128 @@ class ColonyRuntime:
|
||||
return PipelineRunner([])
|
||||
return build_pipeline_from_config(stages_config)
|
||||
|
||||
@staticmethod
|
||||
def _build_default_skills_config(
|
||||
colony_name: str | None,
|
||||
queen_id: str | None,
|
||||
) -> SkillsManagerConfig:
|
||||
"""Assemble a ``SkillsManagerConfig`` that wires in the per-colony /
|
||||
per-queen override files and the ``queen_ui`` / ``colony_ui`` scope
|
||||
dirs based on the standard ``~/.hive`` layout.
|
||||
|
||||
``colony_name`` must be an actual on-disk colony name
|
||||
(``~/.hive/colonies/{name}/``). DM sessions where the ``colony_id``
|
||||
is a session UUID should pass ``None`` so we don't create a stray
|
||||
override file under a session identifier.
|
||||
"""
|
||||
from framework.config import COLONIES_DIR, QUEENS_DIR
|
||||
from framework.skills.discovery import ExtraScope
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
extras: list[ExtraScope] = []
|
||||
queen_overrides_path: Path | None = None
|
||||
if queen_id:
|
||||
queen_home = QUEENS_DIR / queen_id
|
||||
queen_overrides_path = queen_home / "skills_overrides.json"
|
||||
extras.append(ExtraScope(directory=queen_home / "skills", label="queen_ui", priority=2))
|
||||
|
||||
colony_overrides_path: Path | None = None
|
||||
if colony_name:
|
||||
colony_home = COLONIES_DIR / colony_name
|
||||
colony_overrides_path = colony_home / "skills_overrides.json"
|
||||
# Colony-scope SKILL.md dir is the project-scope from discovery's
|
||||
# point of view (colony_dir is the project_root). Add it also as
|
||||
# a tagged ``colony_ui`` scope so UI-created entries resolve with
|
||||
# correct provenance.
|
||||
extras.append(
|
||||
ExtraScope(
|
||||
directory=colony_home / ".hive" / "skills",
|
||||
label="colony_ui",
|
||||
priority=3,
|
||||
)
|
||||
)
|
||||
|
||||
return SkillsManagerConfig(
|
||||
queen_id=queen_id,
|
||||
queen_overrides_path=queen_overrides_path,
|
||||
colony_name=colony_name,
|
||||
colony_overrides_path=colony_overrides_path,
|
||||
extra_scope_dirs=extras,
|
||||
interactive=False, # HTTP-driven runtimes never prompt for consent
|
||||
)
|
||||
|
||||
    @property
    def queen_id(self) -> str | None:
        """The queen profile that owns this runtime, if known.

        Set once at construction (from the ``queen_id`` argument);
        read-only thereafter.
        """
        return self._queen_id
|
||||
|
||||
    @property
    def colony_name(self) -> str | None:
        """The on-disk colony name under ``~/.hive/colonies/``.

        Distinct from the event-bus scope ``colony_id``: the two
        coincide for forked colonies but diverge for queen DM sessions.
        """
        return self._colony_name
|
||||
|
||||
    @property
    def skills_manager(self):
        """The live :class:`SkillsManager` instance (for HTTP handlers).

        Callers that mutate overrides should serialize through its
        ``mutation_lock`` — see :meth:`reload_skills`.
        """
        return self._skills_manager
|
||||
|
||||
async def reload_skills(self) -> dict[str, Any]:
|
||||
"""Rebuild the catalog after an override change; in-flight workers
|
||||
pick up the new catalog on their next iteration via
|
||||
``dynamic_skills_catalog_provider``.
|
||||
|
||||
Returns a small stats dict that HTTP handlers can echo back to
|
||||
the UI ("applied — N skills now in catalog").
|
||||
"""
|
||||
async with self._skills_manager.mutation_lock:
|
||||
self._skills_manager.reload()
|
||||
self.skill_dirs = self._skills_manager.allowlisted_dirs
|
||||
self.batch_init_nudge = self._skills_manager.batch_init_nudge
|
||||
self.context_warn_ratio = self._skills_manager.context_warn_ratio
|
||||
catalog_prompt = self._skills_manager.skills_catalog_prompt
|
||||
return {
|
||||
"catalog_chars": len(catalog_prompt),
|
||||
"skill_dirs": list(self.skill_dirs),
|
||||
}
|
||||
|
||||
# ── Per-colony tool allowlist ───────────────────────────────
|
||||
|
||||
def set_tool_allowlist(
|
||||
self,
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
mcp_tool_names_all: set[str] | None = None,
|
||||
) -> None:
|
||||
"""Configure the per-colony MCP tool allowlist.
|
||||
|
||||
Called at construction time (from SessionManager) and again from
|
||||
the ``/api/colony/{name}/tools`` PATCH handler when a user edits
|
||||
the allowlist. The change applies to the NEXT worker spawn — we
|
||||
never mutate the tool list of a worker that is already running
|
||||
(workers have no dynamic tools provider, so hot-reloading their
|
||||
tool set would diverge from the list the LLM was already using).
|
||||
"""
|
||||
self._enabled_mcp_tools = list(enabled_mcp_tools) if enabled_mcp_tools is not None else None
|
||||
if mcp_tool_names_all is not None:
|
||||
self._mcp_tool_names_all = set(mcp_tool_names_all)
|
||||
|
||||
def _apply_tool_allowlist(self, tools: list) -> list:
|
||||
"""Filter ``tools`` against the colony's MCP allowlist.
|
||||
|
||||
Lifecycle / synthetic tools (those whose names are NOT in
|
||||
``_mcp_tool_names_all``) are never gated. MCP tools are kept only
|
||||
when ``_enabled_mcp_tools`` is None (default allow) or contains
|
||||
their name. Input list order is preserved so downstream cache
|
||||
keys and logs stay stable.
|
||||
"""
|
||||
if self._enabled_mcp_tools is None:
|
||||
return tools
|
||||
allowed = set(self._enabled_mcp_tools)
|
||||
return [
|
||||
t
|
||||
for t in tools
|
||||
if getattr(t, "name", None) not in self._mcp_tool_names_all or getattr(t, "name", None) in allowed
|
||||
]
|
||||
|
||||
# ── Lifecycle ───────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -658,6 +853,14 @@ class ColonyRuntime:
|
||||
spawn_tools = tools if tools is not None else self._tools
|
||||
spawn_executor = tool_executor or self._tool_executor
|
||||
|
||||
# Apply the per-colony MCP tool allowlist (if any). Done HERE —
|
||||
# after spawn_tools is resolved but before it's frozen into the
|
||||
# worker's AgentContext — so the next spawn reflects any PATCH
|
||||
# that happened since the last spawn. A value of ``None`` on
|
||||
# ``_enabled_mcp_tools`` is a no-op so the default path is
|
||||
# unchanged.
|
||||
spawn_tools = self._apply_tool_allowlist(spawn_tools)
|
||||
|
||||
# Colony progress tracker: when the caller supplied a db_path
|
||||
# in input_data, this worker is part of a SQLite task queue
|
||||
# and must see the hive.colony-progress-tracker skill body in
|
||||
@@ -740,6 +943,34 @@ class ColonyRuntime:
|
||||
conversation_store=worker_conv_store,
|
||||
)
|
||||
|
||||
# Workers pick up UI-driven override changes via this provider,
|
||||
# which reads the live catalog on each iteration. The db_path
|
||||
# pre-activated catalog stays static because its contents are
|
||||
# built for *this* worker's task (a tombstone toggle from the
|
||||
# UI should not yank it mid-run).
|
||||
_db_path_pre_activated = bool(isinstance(input_data, dict) and input_data.get("db_path"))
|
||||
# Default-bind the manager into the closure so each loop iteration
|
||||
# captures the same manager instance — pyflakes B023 would flag a
|
||||
# free-variable capture here.
|
||||
_provider = None if _db_path_pre_activated else (lambda mgr=self._skills_manager: mgr.skills_catalog_prompt)
|
||||
|
||||
# Task-system fields. Each worker owns its session task list;
|
||||
# picked_up_from records the colony template entry it was
|
||||
# spawned for, when applicable.
|
||||
from framework.tasks.scoping import (
|
||||
colony_task_list_id as _colony_list_id,
|
||||
session_task_list_id as _session_list_id,
|
||||
)
|
||||
|
||||
_worker_list_id = _session_list_id(worker_id, worker_id)
|
||||
_picked_up = None
|
||||
_template_id = input_data.get("__template_task_id") if isinstance(input_data, dict) else None
|
||||
if _template_id is not None:
|
||||
try:
|
||||
_picked_up = (_colony_list_id(self._colony_id), int(_template_id))
|
||||
except (TypeError, ValueError):
|
||||
_picked_up = None
|
||||
|
||||
agent_context = AgentContext(
|
||||
runtime=self._make_runtime_adapter(worker_id),
|
||||
agent_id=worker_id,
|
||||
@@ -753,8 +984,12 @@ class ColonyRuntime:
|
||||
skills_catalog_prompt=_spawn_catalog,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=_spawn_skill_dirs,
|
||||
dynamic_skills_catalog_provider=_provider,
|
||||
execution_id=worker_id,
|
||||
stream_id=explicit_stream_id or f"worker:{worker_id}",
|
||||
task_list_id=_worker_list_id,
|
||||
colony_id=self._colony_id,
|
||||
picked_up_from=_picked_up,
|
||||
)
|
||||
|
||||
worker = Worker(
|
||||
@@ -997,6 +1232,7 @@ class ColonyRuntime:
|
||||
conversation_store=overseer_conv_store,
|
||||
)
|
||||
|
||||
_overseer_skills_mgr = self._skills_manager
|
||||
overseer_ctx = AgentContext(
|
||||
runtime=self._make_runtime_adapter(overseer_id),
|
||||
agent_id=overseer_id,
|
||||
@@ -1010,6 +1246,7 @@ class ColonyRuntime:
|
||||
skills_catalog_prompt=self.skills_catalog_prompt,
|
||||
protocols_prompt=self.protocols_prompt,
|
||||
skill_dirs=self.skill_dirs,
|
||||
dynamic_skills_catalog_provider=lambda: _overseer_skills_mgr.skills_catalog_prompt,
|
||||
execution_id=overseer_id,
|
||||
stream_id="overseer",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Per-colony tool configuration sidecar (``tools.json``).
|
||||
|
||||
Lives at ``~/.hive/colonies/{colony_name}/tools.json`` alongside
|
||||
``metadata.json``. Kept separate so provenance (queen_name,
|
||||
created_at, workers) stays in metadata while the user-editable tool
|
||||
allowlist gets its own file.
|
||||
|
||||
Schema::
|
||||
|
||||
{
|
||||
"enabled_mcp_tools": ["read_file", ...] | null,
|
||||
"updated_at": "2026-04-21T12:34:56+00:00"
|
||||
}
|
||||
|
||||
- ``null`` / missing file → default "allow every MCP tool".
|
||||
- ``[]`` → explicitly disable every MCP tool.
|
||||
- ``["foo", "bar"]`` → only those MCP tool names pass the filter.
|
||||
|
||||
Atomic writes via ``os.replace`` mirror
|
||||
``framework.host.colony_metadata.update_colony_metadata``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from framework.config import COLONIES_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tools_config_path(colony_name: str) -> Path:
    """Absolute path of the colony's ``tools.json`` sidecar."""
    return COLONIES_DIR / colony_name / "tools.json"
|
||||
|
||||
|
||||
def _metadata_path(colony_name: str) -> Path:
    """Sibling ``metadata.json`` path (the allowlist's legacy home)."""
    return COLONIES_DIR / colony_name / "metadata.json"
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp = tempfile.mkstemp(
|
||||
prefix=".tools.",
|
||||
suffix=".json.tmp",
|
||||
dir=str(path.parent),
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
json.dump(data, fh, indent=2)
|
||||
fh.flush()
|
||||
os.fsync(fh.fileno())
|
||||
os.replace(tmp, path)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _migrate_from_metadata_if_needed(colony_name: str) -> list[str] | None:
    """One-time hoist of a legacy ``enabled_mcp_tools`` field out of
    ``metadata.json``.

    Returns the migrated allowlist (``None`` when there was nothing to
    migrate). After migration the ``tools.json`` sidecar exists and
    ``metadata.json`` no longer carries the field. Safe to call
    repeatedly.
    """
    legacy_path = _metadata_path(colony_name)
    if not legacy_path.exists():
        return None
    try:
        meta = json.loads(legacy_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        logger.warning("Could not read metadata.json during tools migration: %s", colony_name)
        return None
    if not isinstance(meta, dict) or "enabled_mcp_tools" not in meta:
        return None

    legacy_value = meta.pop("enabled_mcp_tools")
    migrated: list[str] | None = None
    if isinstance(legacy_value, list) and all(isinstance(item, str) for item in legacy_value):
        migrated = legacy_value
    elif legacy_value is not None:
        # Anything that isn't None or list[str] is dropped loudly.
        logger.warning(
            "Legacy enabled_mcp_tools on colony %s had unexpected shape %r; dropping",
            colony_name,
            legacy_value,
        )

    # Write the sidecar before rewriting metadata so a partial failure
    # leaves the config recoverable.
    _atomic_write_json(
        tools_config_path(colony_name),
        {
            "enabled_mcp_tools": migrated,
            "updated_at": datetime.now(UTC).isoformat(),
        },
    )
    _atomic_write_json(legacy_path, meta)
    logger.info(
        "Migrated enabled_mcp_tools for colony %s from metadata.json to tools.json",
        colony_name,
    )
    return migrated
|
||||
|
||||
|
||||
def load_colony_tools_config(colony_name: str) -> list[str] | None:
    """Return the colony's MCP tool allowlist, or ``None`` for default-allow.

    Order of resolution:
    1. ``tools.json`` sidecar (authoritative).
    2. Legacy ``metadata.json`` field (migrated and deleted on first read).
    3. ``None`` — default "allow every MCP tool".
    """
    path = tools_config_path(colony_name)
    if not path.exists():
        # No sidecar yet: consult (and migrate away) the legacy field.
        return _migrate_from_metadata_if_needed(colony_name)

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        logger.warning("Invalid %s; treating as default-allow", path)
        return None
    if not isinstance(payload, dict):
        return None

    tools = payload.get("enabled_mcp_tools")
    if tools is None:
        return None
    if isinstance(tools, list) and all(isinstance(item, str) for item in tools):
        return tools
    logger.warning("Unexpected enabled_mcp_tools shape in %s; ignoring", path)
    return None
|
||||
|
||||
|
||||
def update_colony_tools_config(
    colony_name: str,
    enabled_mcp_tools: list[str] | None,
) -> list[str] | None:
    """Persist a colony's MCP allowlist to ``tools.json``.

    Raises ``FileNotFoundError`` if the colony's directory is missing.
    """
    if not (COLONIES_DIR / colony_name).exists():
        raise FileNotFoundError(f"Colony directory not found: {colony_name}")
    payload = {
        "enabled_mcp_tools": enabled_mcp_tools,
        "updated_at": datetime.now(UTC).isoformat(),
    }
    _atomic_write_json(tools_config_path(colony_name), payload)
    return enabled_mcp_tools
|
||||
@@ -165,6 +165,14 @@ class EventType(StrEnum):
|
||||
TRIGGER_REMOVED = "trigger_removed"
|
||||
TRIGGER_UPDATED = "trigger_updated"
|
||||
|
||||
# Task system lifecycle (per-list diffs streamed to the UI)
|
||||
TASK_CREATED = "task_created"
|
||||
TASK_UPDATED = "task_updated"
|
||||
TASK_DELETED = "task_deleted"
|
||||
TASK_LIST_RESET = "task_list_reset"
|
||||
TASK_LIST_REATTACH_MISMATCH = "task_list_reattach_mismatch"
|
||||
COLONY_TEMPLATE_ASSIGNMENT = "colony_template_assignment"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentEvent:
|
||||
@@ -809,16 +817,28 @@ class EventBus:
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cost_usd: float = 0.0,
|
||||
execution_id: str | None = None,
|
||||
iteration: int | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM turn completion with stop reason and model metadata."""
|
||||
"""Emit LLM turn completion with stop reason and model metadata.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (already inside provider ``prompt_tokens``).
|
||||
Subscribers should display them, not add them to a total.
|
||||
|
||||
``cost_usd`` is the USD cost for this turn when known (Anthropic,
|
||||
OpenAI, OpenRouter). 0.0 means unreported (not free).
|
||||
"""
|
||||
data: dict = {
|
||||
"stop_reason": stop_reason,
|
||||
"model": model,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cached_tokens": cached_tokens,
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
"cost_usd": cost_usd,
|
||||
}
|
||||
if iteration is not None:
|
||||
data["iteration"] = iteration
|
||||
|
||||
@@ -154,11 +154,21 @@ class Worker:
|
||||
# value without affecting the queen's ongoing calls.
|
||||
try:
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
from framework.tasks.scoping import session_task_list_id
|
||||
|
||||
ToolRegistry.set_execution_context(profile=self.id)
|
||||
ctx = self._context
|
||||
agent_id = getattr(ctx, "agent_id", None) or self.id
|
||||
list_id = getattr(ctx, "task_list_id", None) or session_task_list_id(agent_id, self.id)
|
||||
ToolRegistry.set_execution_context(
|
||||
profile=self.id,
|
||||
agent_id=agent_id,
|
||||
task_list_id=list_id,
|
||||
colony_id=getattr(ctx, "colony_id", None),
|
||||
picked_up_from=getattr(ctx, "picked_up_from", None),
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Worker %s: failed to scope browser profile",
|
||||
"Worker %s: failed to scope execution context",
|
||||
self.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -653,10 +653,17 @@ class AntigravityProvider(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
import asyncio # noqa: PLC0415
|
||||
import concurrent.futures # noqa: PLC0415
|
||||
|
||||
# Antigravity (Google's proprietary endpoint) doesn't expose a
|
||||
# cache_control hook. Concatenate the dynamic suffix so its shape
|
||||
# matches the legacy single-string call site.
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
|
||||
|
||||
|
||||
@@ -1,114 +1,32 @@
|
||||
"""Model capability checks for LLM providers.
|
||||
|
||||
Vision support rules are derived from official vendor documentation:
|
||||
- ZAI (z.ai): docs.z.ai/guides/vlm — GLM-4.6V variants are vision; GLM-5/4.6/4.7 are text-only
|
||||
- MiniMax: platform.minimax.io/docs — minimax-vl-01 is vision; M2.x are text-only
|
||||
- DeepSeek: api-docs.deepseek.com — deepseek-vl2 is vision; chat/reasoner are text-only
|
||||
- Cerebras: inference-docs.cerebras.ai — no vision models at all
|
||||
- Groq: console.groq.com/docs/vision — vision capable; treat as supported by default
|
||||
- Ollama/LM Studio/vLLM/llama.cpp: local runners denied by default; model names
|
||||
don't reliably indicate vision support, so users must configure explicitly
|
||||
Vision support is sourced from the curated ``model_catalog.json``. Each model
|
||||
entry carries an optional ``supports_vision`` boolean; unknown models default
|
||||
to vision-capable so hosted frontier models work out of the box. To toggle
|
||||
support for a model, edit its catalog entry rather than this file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from framework.llm.model_catalog import model_supports_vision
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.llm.provider import Tool
|
||||
|
||||
|
||||
def _model_name(model: str) -> str:
|
||||
"""Return the bare model name after stripping any 'provider/' prefix."""
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
# Step 1: explicit vision allow-list — these always support images regardless
|
||||
# of what the provider-level rules say. Checked first so that e.g. glm-4.6v
|
||||
# is allowed even though glm-4.6 is denied.
|
||||
_VISION_ALLOW_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# ZAI/GLM vision models (docs.z.ai/guides/vlm)
|
||||
"glm-4v", # GLM-4V series (legacy)
|
||||
"glm-4.6v", # GLM-4.6V, GLM-4.6V-flash, GLM-4.6V-flashx
|
||||
# DeepSeek vision models
|
||||
"deepseek-vl", # deepseek-vl2, deepseek-vl2-small, deepseek-vl2-tiny
|
||||
# MiniMax vision model
|
||||
"minimax-vl", # minimax-vl-01
|
||||
)
|
||||
|
||||
# Step 2: provider-level deny — every model from this provider is text-only.
|
||||
_TEXT_ONLY_PROVIDER_PREFIXES: tuple[str, ...] = (
|
||||
# Cerebras: inference-docs.cerebras.ai lists only text models
|
||||
"cerebras/",
|
||||
# Local runners: model names don't reliably indicate vision support
|
||||
"ollama/",
|
||||
"ollama_chat/",
|
||||
"lm_studio/",
|
||||
"vllm/",
|
||||
"llamacpp/",
|
||||
)
|
||||
|
||||
# Step 3: per-model deny — text-only models within otherwise mixed providers.
|
||||
# Matched against the bare model name (provider prefix stripped, lower-cased).
|
||||
# The vision allow-list above is checked first, so vision variants of the same
|
||||
# family are already handled before these deny patterns are reached.
|
||||
_TEXT_ONLY_MODEL_BARE_PREFIXES: tuple[str, ...] = (
|
||||
# --- ZAI / GLM family ---
|
||||
# text-only: glm-5, glm-4.6, glm-4.7, glm-4.5, zai-glm-*
|
||||
# vision: glm-4v, glm-4.6v (caught by allow-list above)
|
||||
"glm-5",
|
||||
"glm-4.6", # bare glm-4.6 is text-only; glm-4.6v is caught by allow-list
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"zai-glm",
|
||||
# --- DeepSeek ---
|
||||
# text-only: deepseek-chat, deepseek-coder, deepseek-reasoner
|
||||
# vision: deepseek-vl2 (caught by allow-list above)
|
||||
# Note: LiteLLM's deepseek handler may flatten content lists for some models;
|
||||
# VL models are allowed through and rely on LiteLLM's native VL support.
|
||||
"deepseek-chat",
|
||||
"deepseek-coder",
|
||||
"deepseek-reasoner",
|
||||
# --- MiniMax ---
|
||||
# text-only: minimax-m2.*, minimax-text-*, abab* (legacy)
|
||||
# vision: minimax-vl-01 (caught by allow-list above)
|
||||
"minimax-m2",
|
||||
"minimax-text",
|
||||
"abab",
|
||||
)
|
||||
|
||||
|
||||
def supports_image_tool_results(model: str) -> bool:
    """Return whether *model* can receive image content in messages.

    Thin wrapper over :func:`model_supports_vision` so existing call sites
    keep working. Used to gate both user-message images and tool-result
    image blocks. Empty model strings are treated as capable so the default
    code path doesn't strip images before a provider is selected.
    """
    # The block previously contained fused old/new diff residue: the
    # retired prefix-list checks were interleaved with the new catalog
    # lookup, leaving unreachable and contradictory branches. Resolved to
    # the catalog-backed wrapper the surrounding module now documents.
    if not model:
        return True
    return model_supports_vision(model)
|
||||
|
||||
|
||||
def filter_tools_for_model(tools: list[Tool], model: str) -> tuple[list[Tool], list[str]]:
|
||||
|
||||
+350
-28
@@ -33,6 +33,7 @@ except ImportError:
|
||||
RateLimitError = Exception # type: ignore[assignment, misc]
|
||||
|
||||
from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
|
||||
from framework.llm.model_catalog import get_model_pricing
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
from framework.llm.stream_events import StreamEvent
|
||||
|
||||
@@ -213,9 +214,72 @@ _CACHE_CONTROL_PREFIXES = (
|
||||
"glm-",
|
||||
)
|
||||
|
||||
# OpenRouter sub-provider prefixes whose upstream API honors `cache_control`.
|
||||
# OpenRouter passes the marker through to the underlying provider for these.
|
||||
# (See https://openrouter.ai/docs/guides/best-practices/prompt-caching.)
|
||||
# OpenAI/DeepSeek/Groq/Grok/Moonshot route through OpenRouter but cache
|
||||
# automatically server-side — sending cache_control there is a no-op, not a
|
||||
# win, and they need a separate prefix-stability fix to actually get hits.
|
||||
_OPENROUTER_CACHE_CONTROL_PREFIXES = (
|
||||
"openrouter/anthropic/",
|
||||
"openrouter/google/gemini-",
|
||||
"openrouter/z-ai/glm",
|
||||
"openrouter/minimax/",
|
||||
)
|
||||
|
||||
|
||||
def _model_supports_cache_control(model: str) -> bool:
    """Return True when *model*'s endpoint honors ``cache_control`` markers.

    Checks the direct-provider prefixes first, then the OpenRouter
    sub-provider prefixes whose upstream API passes the marker through.
    """
    # Defect fixed: a stale early `return any(...)` (leftover removed diff
    # line) made the _OPENROUTER_CACHE_CONTROL_PREFIXES check unreachable,
    # so openrouter/anthropic/* etc. never got cache_control markers.
    if any(model.startswith(p) for p in _CACHE_CONTROL_PREFIXES):
        return True
    return any(model.startswith(p) for p in _OPENROUTER_CACHE_CONTROL_PREFIXES)
|
||||
|
||||
|
||||
def _build_system_message(
|
||||
system: str,
|
||||
system_dynamic_suffix: str | None,
|
||||
model: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Construct the system-role message for the chat completion.
|
||||
|
||||
Returns ``None`` when there is nothing to send.
|
||||
|
||||
Two-block split path — used when the caller supplied a non-empty
|
||||
``system_dynamic_suffix`` AND the provider honors ``cache_control``
|
||||
(Anthropic, MiniMax, Z-AI/GLM). We emit ``content`` as a list of two
|
||||
text blocks with an ephemeral ``cache_control`` marker on the first
|
||||
block only. The prompt cache keeps the static prefix warm across
|
||||
turns and across iterations within a turn; only the small dynamic
|
||||
tail is recomputed on every request.
|
||||
|
||||
Single-string path — used for every other case (no suffix provided,
|
||||
or provider doesn't honor ``cache_control``). We concatenate
|
||||
``system`` + ``\\n\\n`` + ``system_dynamic_suffix`` and attach
|
||||
``cache_control`` to the whole message when the provider supports
|
||||
it. This is byte-identical to the pre-split behavior for all
|
||||
non-cache-control providers (OpenAI, Gemini, Groq, Ollama, etc.).
|
||||
"""
|
||||
if not system and not system_dynamic_suffix:
|
||||
return None
|
||||
if system_dynamic_suffix and _model_supports_cache_control(model):
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if system:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": system,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
)
|
||||
content_blocks.append({"type": "text", "text": system_dynamic_suffix})
|
||||
return {"role": "system", "content": content_blocks}
|
||||
# Single-string path (legacy or no-cache-control provider).
|
||||
combined = system
|
||||
if system_dynamic_suffix:
|
||||
combined = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": combined}
|
||||
if _model_supports_cache_control(model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
return sys_msg
|
||||
|
||||
|
||||
# Kimi For Coding uses an Anthropic-compatible endpoint (no /v1 suffix).
|
||||
@@ -297,6 +361,171 @@ FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
MAX_FAILED_REQUEST_DUMPS = 50
|
||||
|
||||
|
||||
def _cost_from_catalog_pricing(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> float:
|
||||
"""Last-resort cost calculation using curated catalog pricing.
|
||||
|
||||
Consulted only when the provider response carries no native cost and
|
||||
LiteLLM's own catalog has no pricing for ``model``. Reads
|
||||
``pricing_usd_per_mtok`` from ``model_catalog.json``. Rates are USD per
|
||||
million tokens.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (see ``_extract_cache_tokens``), so subtract them from
|
||||
the base input count to avoid double-billing. If a cache rate is absent,
|
||||
fall back to the plain input rate.
|
||||
"""
|
||||
if not model or (input_tokens == 0 and output_tokens == 0):
|
||||
return 0.0
|
||||
pricing = get_model_pricing(model)
|
||||
if pricing is None and "/" in model:
|
||||
# LiteLLM prefixes some ids (e.g. "openrouter/z-ai/glm-5.1"); the
|
||||
# catalog stores the bare form ("z-ai/glm-5.1"). Strip one segment.
|
||||
pricing = get_model_pricing(model.split("/", 1)[1])
|
||||
if pricing is None:
|
||||
return 0.0
|
||||
|
||||
per_mtok_in = pricing.get("input", 0.0)
|
||||
per_mtok_out = pricing.get("output", 0.0)
|
||||
per_mtok_cache_read = pricing.get("cache_read", per_mtok_in)
|
||||
per_mtok_cache_write = pricing.get("cache_creation", per_mtok_in)
|
||||
|
||||
plain_input = max(input_tokens - cached_tokens - cache_creation_tokens, 0)
|
||||
total = (
|
||||
plain_input * per_mtok_in
|
||||
+ cached_tokens * per_mtok_cache_read
|
||||
+ cache_creation_tokens * per_mtok_cache_write
|
||||
+ output_tokens * per_mtok_out
|
||||
) / 1_000_000
|
||||
return float(total) if total > 0 else 0.0
|
||||
|
||||
|
||||
def _extract_cost(response: Any, model: str) -> float:
|
||||
"""Pull the USD cost for a non-streaming completion response.
|
||||
|
||||
Sources checked, in priority order:
|
||||
1. ``usage.cost`` — populated when OpenRouter returns native cost via
|
||||
``usage: {include: true}`` or when ``litellm.include_cost_in_streaming_usage``
|
||||
is on.
|
||||
2. ``response._hidden_params["response_cost"]`` — set by LiteLLM's
|
||||
logging layer after most successful completions.
|
||||
3. ``litellm.completion_cost(...)`` — computes from the model pricing
|
||||
table; works across Anthropic, OpenAI, and OpenRouter as long as the
|
||||
model is in LiteLLM's catalog.
|
||||
4. ``pricing_usd_per_mtok`` from the curated model catalog — covers
|
||||
models (e.g. GLM, Kimi, MiniMax) that LiteLLM doesn't price.
|
||||
|
||||
Returns 0.0 for unpriced models or unexpected response shapes — cost is a
|
||||
display concern, never let it break the hot path. For streaming paths
|
||||
where the aggregate response isn't a full ``ModelResponse``, use
|
||||
:func:`_cost_from_tokens` with the already-extracted token counts.
|
||||
"""
|
||||
if response is None:
|
||||
return 0.0
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_cost = getattr(usage, "cost", None) if usage is not None else None
|
||||
if isinstance(usage_cost, (int, float)) and usage_cost > 0:
|
||||
return float(usage_cost)
|
||||
|
||||
hidden = getattr(response, "_hidden_params", None)
|
||||
if isinstance(hidden, dict):
|
||||
hp_cost = hidden.get("response_cost")
|
||||
if isinstance(hp_cost, (int, float)) and hp_cost > 0:
|
||||
return float(hp_cost)
|
||||
|
||||
try:
|
||||
import litellm as _litellm
|
||||
|
||||
computed = _litellm.completion_cost(completion_response=response, model=model)
|
||||
if isinstance(computed, (int, float)) and computed > 0:
|
||||
return float(computed)
|
||||
except Exception as exc:
|
||||
logger.debug("[cost] completion_cost failed for %s: %s", model, exc)
|
||||
|
||||
if usage is not None:
|
||||
input_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
|
||||
output_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
|
||||
cache_read, cache_creation = _extract_cache_tokens(usage)
|
||||
fallback = _cost_from_catalog_pricing(model, input_tokens, output_tokens, cache_read, cache_creation)
|
||||
if fallback > 0:
|
||||
return fallback
|
||||
return 0.0
|
||||
|
||||
|
||||
def _cost_from_tokens(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cached_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> float:
|
||||
"""Compute USD cost from already-normalized token counts.
|
||||
|
||||
Used on streaming paths where the aggregate ``response`` is the stream
|
||||
wrapper (not a full ``ModelResponse``) and ``litellm.completion_cost`` on
|
||||
it either no-ops or raises. Calls ``litellm.cost_per_token`` directly
|
||||
with the cache-aware inputs so Anthropic's 5-min-write / cache-read
|
||||
multipliers are applied correctly.
|
||||
"""
|
||||
if not model or (input_tokens == 0 and output_tokens == 0):
|
||||
return 0.0
|
||||
try:
|
||||
import litellm as _litellm
|
||||
|
||||
prompt_cost, completion_cost = _litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
cache_read_input_tokens=cached_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
)
|
||||
total = (prompt_cost or 0.0) + (completion_cost or 0.0)
|
||||
if total > 0:
|
||||
return float(total)
|
||||
except Exception as exc:
|
||||
logger.debug("[cost] cost_per_token failed for %s: %s", model, exc)
|
||||
return _cost_from_catalog_pricing(model, input_tokens, output_tokens, cached_tokens, cache_creation_tokens)
|
||||
|
||||
|
||||
def _extract_cache_tokens(usage: Any) -> tuple[int, int]:
|
||||
"""Pull (cache_read, cache_creation) from a LiteLLM usage object.
|
||||
|
||||
Both are subsets of ``prompt_tokens`` already — providers count them
|
||||
inside the input total. Surface separately for visibility, never sum.
|
||||
|
||||
Field names vary by provider/proxy; check the known shapes in priority
|
||||
order and fall back to 0:
|
||||
|
||||
cache_read:
|
||||
- ``prompt_tokens_details.cached_tokens`` — OpenAI-shape; also what
|
||||
LiteLLM normalizes Anthropic and OpenRouter into.
|
||||
- ``cache_read_input_tokens`` — raw Anthropic field name.
|
||||
|
||||
cache_creation:
|
||||
- ``prompt_tokens_details.cache_write_tokens`` — OpenRouter's
|
||||
normalized field for cache writes (verified empirically against
|
||||
``openrouter/anthropic/*`` and ``openrouter/z-ai/*`` responses).
|
||||
- ``cache_creation_input_tokens`` — raw Anthropic top-level field.
|
||||
"""
|
||||
if not usage:
|
||||
return 0, 0
|
||||
_details = getattr(usage, "prompt_tokens_details", None)
|
||||
cache_read = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
cache_creation = (getattr(_details, "cache_write_tokens", 0) or 0 if _details is not None else 0) or (
|
||||
getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
)
|
||||
return cache_read, cache_creation
|
||||
|
||||
|
||||
def _estimate_tokens(model: str, messages: list[dict]) -> tuple[int, str]:
|
||||
"""Estimate token count for messages. Returns (token_count, method)."""
|
||||
# Try litellm's token counter first
|
||||
@@ -1015,12 +1244,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=response.choices[0].finish_reason or "",
|
||||
raw_response=response,
|
||||
)
|
||||
@@ -1169,8 +1403,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async version of complete(). Uses litellm.acompletion — non-blocking."""
|
||||
"""Async version of complete(). Uses litellm.acompletion — non-blocking.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail. When set and
|
||||
the provider honors ``cache_control``, ``system`` is sent as the
|
||||
cached prefix and the suffix trails as an uncached second content
|
||||
block. Otherwise the two strings are concatenated into a single
|
||||
system message (legacy behavior).
|
||||
"""
|
||||
# Codex ChatGPT backend requires streaming — route through stream() which
|
||||
# already handles Codex quirks and has proper tool call accumulation.
|
||||
if self._codex_backend:
|
||||
@@ -1181,6 +1423,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
return await self._collect_stream_to_response(stream_iter)
|
||||
|
||||
@@ -1188,10 +1431,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
sys_msg = _build_system_message(system, system_dynamic_suffix, self.model)
|
||||
if sys_msg is not None:
|
||||
full_messages.append(sys_msg)
|
||||
full_messages.extend(messages)
|
||||
|
||||
@@ -1228,12 +1469,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=response.choices[0].finish_reason or "",
|
||||
raw_response=response,
|
||||
)
|
||||
@@ -1619,6 +1865,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
messages: list[dict[str, Any]],
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a JSON-only prompt for models without native tool support."""
|
||||
tool_specs = [
|
||||
@@ -1646,7 +1893,19 @@ class LiteLLMProvider(LLMProvider):
|
||||
)
|
||||
compat_system = compat_instruction if not system else f"{system}\n\n{compat_instruction}"
|
||||
|
||||
full_messages: list[dict[str, Any]] = [{"role": "system", "content": compat_system}]
|
||||
# If the routed sub-provider honors cache_control (e.g.
|
||||
# openrouter/anthropic/*), split the static prefix from the dynamic
|
||||
# suffix so the prefix stays cache-warm across turns. Otherwise fall
|
||||
# back to a single concatenated string.
|
||||
system_message = _build_system_message(
|
||||
compat_system,
|
||||
system_dynamic_suffix,
|
||||
self.model,
|
||||
)
|
||||
|
||||
full_messages: list[dict[str, Any]] = []
|
||||
if system_message is not None:
|
||||
full_messages.append(system_message)
|
||||
full_messages.extend(messages)
|
||||
return [
|
||||
message
|
||||
@@ -1660,9 +1919,21 @@ class LiteLLMProvider(LLMProvider):
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools."""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(messages, system, tools)
|
||||
"""Emulate tool calling via JSON when OpenRouter rejects native tools.
|
||||
|
||||
When the routed sub-provider honors ``cache_control`` (e.g.
|
||||
``openrouter/anthropic/*``), the message builder splits the static
|
||||
prefix from the dynamic suffix so the prefix stays cache-warm.
|
||||
Otherwise the suffix is concatenated into a single system string.
|
||||
"""
|
||||
full_messages = self._build_openrouter_tool_compat_messages(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": full_messages,
|
||||
@@ -1683,6 +1954,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
usage = response.usage
|
||||
input_tokens = usage.prompt_tokens if usage else 0
|
||||
output_tokens = usage.completion_tokens if usage else 0
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
cost_usd = _extract_cost(response, self.model)
|
||||
stop_reason = "tool_calls" if tool_calls else (response.choices[0].finish_reason or "stop")
|
||||
|
||||
return LLMResponse(
|
||||
@@ -1690,6 +1963,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
model=response.model or self.model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={
|
||||
"compat_mode": "openrouter_tool_emulation",
|
||||
@@ -1704,6 +1980,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system: str,
|
||||
tools: list[Tool],
|
||||
max_tokens: int,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback stream for OpenRouter models without native tool support."""
|
||||
from framework.llm.stream_events import (
|
||||
@@ -1724,6 +2001,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
@@ -1747,6 +2025,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
@@ -1758,6 +2039,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int,
|
||||
response_format: dict[str, Any] | None,
|
||||
json_mode: bool,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Fallback path: convert non-stream completion to stream events.
|
||||
|
||||
@@ -1781,6 +2063,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamErrorEvent(error=str(e), recoverable=False)
|
||||
@@ -1812,6 +2095,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
stop_reason=response.stop_reason or "stop",
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
@@ -1823,6 +2109,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a completion via litellm.acompletion(stream=True).
|
||||
|
||||
@@ -1833,6 +2120,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
Empty responses (e.g. Gemini stealth rate-limits that return 200
|
||||
with no content) are retried with exponential backoff, mirroring
|
||||
the retry behaviour of ``_completion_with_rate_limit_retry``.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail. See
|
||||
``acomplete`` docstring for the two-block split semantics.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
@@ -1852,6 +2142,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
json_mode=json_mode,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
@@ -1862,6 +2153,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
@@ -1870,10 +2162,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self._claude_code_oauth:
|
||||
billing = _claude_code_billing_header(messages)
|
||||
full_messages.append({"role": "system", "content": billing})
|
||||
if system:
|
||||
sys_msg: dict[str, Any] = {"role": "system", "content": system}
|
||||
if _model_supports_cache_control(self.model):
|
||||
sys_msg["cache_control"] = {"type": "ephemeral"}
|
||||
sys_msg = _build_system_message(system, system_dynamic_suffix, self.model)
|
||||
if sys_msg is not None:
|
||||
full_messages.append(sys_msg)
|
||||
full_messages.extend(messages)
|
||||
|
||||
@@ -2109,37 +2399,44 @@ class LiteLLMProvider(LLMProvider):
|
||||
type(usage).__name__,
|
||||
)
|
||||
cached_tokens = 0
|
||||
cache_creation_tokens = 0
|
||||
if usage:
|
||||
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
_details = getattr(usage, "prompt_tokens_details", None)
|
||||
cached_tokens = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
cached_tokens, cache_creation_tokens = _extract_cache_tokens(usage)
|
||||
logger.debug(
|
||||
"[tokens] finish-chunk usage: input=%d output=%d cached=%d model=%s",
|
||||
"[tokens] finish-chunk usage: input=%d output=%d cached=%d cache_creation=%d model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
self.model,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"[tokens] finish event: input=%d output=%d cached=%d stop=%s model=%s",
|
||||
"[tokens] finish event: input=%d output=%d cached=%d cache_creation=%d stop=%s model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
choice.finish_reason,
|
||||
self.model,
|
||||
)
|
||||
cost_usd = _cost_from_tokens(
|
||||
self.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
)
|
||||
tail_events.append(
|
||||
FinishEvent(
|
||||
stop_reason=choice.finish_reason,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
model=self.model,
|
||||
)
|
||||
)
|
||||
@@ -2159,19 +2456,36 @@ class LiteLLMProvider(LLMProvider):
|
||||
_usage = calculate_total_usage(chunks=_chunks)
|
||||
input_tokens = _usage.prompt_tokens or 0
|
||||
output_tokens = _usage.completion_tokens or 0
|
||||
_details = getattr(_usage, "prompt_tokens_details", None)
|
||||
cached_tokens = (
|
||||
getattr(_details, "cached_tokens", 0) or 0
|
||||
if _details is not None
|
||||
else getattr(_usage, "cache_read_input_tokens", 0) or 0
|
||||
)
|
||||
# `calculate_total_usage` aggregates token totals
|
||||
# but discards `prompt_tokens_details` — which is
|
||||
# where OpenRouter puts `cached_tokens` and
|
||||
# `cache_write_tokens`. Recover them directly
|
||||
# from the most recent chunk that carries usage.
|
||||
cached_tokens, cache_creation_tokens = 0, 0
|
||||
for _raw in reversed(_chunks):
|
||||
_raw_usage = getattr(_raw, "usage", None)
|
||||
if _raw_usage is None:
|
||||
continue
|
||||
_cr, _cc = _extract_cache_tokens(_raw_usage)
|
||||
if _cr or _cc:
|
||||
cached_tokens, cache_creation_tokens = _cr, _cc
|
||||
break
|
||||
logger.debug(
|
||||
"[tokens] post-loop chunks fallback: input=%d output=%d cached=%d model=%s",
|
||||
"[tokens] post-loop chunks fallback: input=%d output=%d "
|
||||
"cached=%d cache_creation=%d model=%s",
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
self.model,
|
||||
)
|
||||
cost_usd = _cost_from_tokens(
|
||||
self.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cached_tokens,
|
||||
cache_creation_tokens,
|
||||
)
|
||||
# Patch the FinishEvent already queued with 0 tokens
|
||||
for _i, _ev in enumerate(tail_events):
|
||||
if isinstance(_ev, FinishEvent) and _ev.input_tokens == 0:
|
||||
@@ -2180,6 +2494,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cost_usd=cost_usd,
|
||||
model=_ev.model,
|
||||
)
|
||||
break
|
||||
@@ -2390,6 +2706,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
cached_tokens = 0
|
||||
cache_creation_tokens = 0
|
||||
stop_reason = ""
|
||||
model = self.model
|
||||
|
||||
@@ -2407,6 +2725,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
elif isinstance(event, FinishEvent):
|
||||
input_tokens = event.input_tokens
|
||||
output_tokens = event.output_tokens
|
||||
cached_tokens = event.cached_tokens
|
||||
cache_creation_tokens = event.cache_creation_tokens
|
||||
stop_reason = event.stop_reason
|
||||
if event.model:
|
||||
model = event.model
|
||||
@@ -2419,6 +2739,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
stop_reason=stop_reason,
|
||||
raw_response={"tool_calls": tool_calls} if tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -155,8 +155,11 @@ class MockLLMProvider(LLMProvider):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Async mock completion (no I/O, returns immediately)."""
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
return self.complete(
|
||||
messages=messages,
|
||||
system=system,
|
||||
@@ -173,6 +176,7 @@ class MockLLMProvider(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream a mock completion as word-level TextDeltaEvents.
|
||||
|
||||
@@ -180,6 +184,8 @@ class MockLLMProvider(LLMProvider):
|
||||
TextDeltaEvent with an accumulating snapshot, exercising the full
|
||||
streaming pipeline without any API calls.
|
||||
"""
|
||||
if system_dynamic_suffix:
|
||||
system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
content = self._generate_mock_response(system=system, json_mode=False)
|
||||
words = content.split(" ")
|
||||
accumulated = ""
|
||||
|
||||
@@ -9,47 +9,65 @@
|
||||
"label": "Haiku 4.5 - Fast + cheap",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 136000
|
||||
"max_context_tokens": 136000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-5-20250929",
|
||||
"label": "Sonnet 4.5 - Best balance",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 136000
|
||||
"max_context_tokens": 136000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "claude-opus-4-6",
|
||||
"label": "Opus 4.6 - Most capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"openai": {
|
||||
"default_model": "gpt-5.4",
|
||||
"default_model": "gpt-5.5",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"label": "GPT-5.4 - Best intelligence",
|
||||
"id": "gpt-5.5",
|
||||
"label": "GPT-5.5 - Frontier coding + reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 960000
|
||||
"max_context_tokens": 1050000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 5.00,
|
||||
"output": 30.00
|
||||
},
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"label": "GPT-5.4 - Previous flagship",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 960000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4-mini",
|
||||
"label": "GPT-5.4 Mini - Faster + cheaper",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 400000
|
||||
"max_context_tokens": 400000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4-nano",
|
||||
"label": "GPT-5.4 Nano - Cheapest high-volume",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 400000
|
||||
"max_context_tokens": 400000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -61,14 +79,16 @@
|
||||
"label": "Gemini 3 Flash - Fast",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "gemini-3.1-pro-preview-customtools",
|
||||
"label": "Gemini 3.1 Pro - Best quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -80,28 +100,32 @@
|
||||
"label": "GPT-OSS 120B - Best reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 65536,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-oss-20b",
|
||||
"label": "GPT-OSS 20B - Fast + cheaper",
|
||||
"recommended": false,
|
||||
"max_tokens": 65536,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "llama-3.3-70b-versatile",
|
||||
"label": "Llama 3.3 70B - General purpose",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "llama-3.1-8b-instant",
|
||||
"label": "Llama 3.1 8B - Fastest",
|
||||
"recommended": false,
|
||||
"max_tokens": 131072,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -113,21 +137,24 @@
|
||||
"label": "GPT-OSS 120B - Best production reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "zai-glm-4.7",
|
||||
"label": "Z.ai GLM 4.7 - Strong coding preview",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "qwen-3-235b-a22b-instruct-2507",
|
||||
"label": "Qwen 3 235B Instruct - Frontier preview",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -139,14 +166,20 @@
|
||||
"label": "MiniMax M2.7 - Best coding quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.30,
|
||||
"output": 1.20
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "MiniMax-M2.5",
|
||||
"label": "MiniMax M2.5 - Strong value",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -158,28 +191,32 @@
|
||||
"label": "Mistral Large 3 - Best quality",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 256000
|
||||
"max_context_tokens": 256000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "mistral-medium-2508",
|
||||
"label": "Mistral Medium 3.1 - Balanced",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "mistral-small-2603",
|
||||
"label": "Mistral Small 4 - Fast + capable",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 256000
|
||||
"max_context_tokens": 256000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "codestral-2508",
|
||||
"label": "Codestral - Coding specialist",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -191,47 +228,71 @@
|
||||
"label": "DeepSeek V3.1 - Best general coding",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8",
|
||||
"label": "Qwen3 Coder 480B - Advanced coding",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 262144
|
||||
"max_context_tokens": 262144,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-oss-120b",
|
||||
"label": "GPT-OSS 120B - Strong reasoning",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"label": "Llama 3.3 70B Turbo - Fast baseline",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 131072
|
||||
"max_context_tokens": 131072,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"deepseek": {
|
||||
"default_model": "deepseek-chat",
|
||||
"default_model": "deepseek-v4-pro",
|
||||
"models": [
|
||||
{
|
||||
"id": "deepseek-chat",
|
||||
"label": "DeepSeek Chat - Fast default",
|
||||
"id": "deepseek-v4-pro",
|
||||
"label": "DeepSeek V4 Pro - Most capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 8192,
|
||||
"max_context_tokens": 128000
|
||||
"max_tokens": 384000,
|
||||
"max_context_tokens": 1000000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.74,
|
||||
"output": 3.48,
|
||||
"cache_read": 0.145
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "deepseek-v4-flash",
|
||||
"label": "DeepSeek V4 Flash - Fast + cheap",
|
||||
"recommended": true,
|
||||
"max_tokens": 384000,
|
||||
"max_context_tokens": 1000000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.14,
|
||||
"output": 0.28,
|
||||
"cache_read": 0.028
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "deepseek-reasoner",
|
||||
"label": "DeepSeek Reasoner - Deep thinking",
|
||||
"label": "DeepSeek Reasoner - Legacy (deprecating)",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 128000
|
||||
"max_context_tokens": 128000,
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -243,7 +304,13 @@
|
||||
"label": "Kimi K2.5 - Best coding",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 200000
|
||||
"max_context_tokens": 200000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.60,
|
||||
"output": 2.50,
|
||||
"cache_read": 0.15
|
||||
},
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -255,21 +322,30 @@
|
||||
"label": "Queen - Hive native",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "kimi-2.5",
|
||||
"label": "Kimi 2.5 - Via Hive",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "GLM-5",
|
||||
"label": "GLM-5 - Via Hive",
|
||||
"id": "glm-5.1",
|
||||
"label": "GLM-5.1 - Via Hive",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.40,
|
||||
"output": 4.40,
|
||||
"cache_read": 0.26,
|
||||
"cache_creation": 0.0
|
||||
},
|
||||
"supports_vision": false
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -281,63 +357,82 @@
|
||||
"label": "GPT-5.4 - Best overall",
|
||||
"recommended": true,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-sonnet-4.6",
|
||||
"label": "Claude Sonnet 4.6 - Best coding balance",
|
||||
"recommended": false,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-opus-4.6",
|
||||
"label": "Claude Opus 4.6 - Most capable",
|
||||
"recommended": false,
|
||||
"max_tokens": 128000,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "google/gemini-3.1-pro-preview-customtools",
|
||||
"label": "Gemini 3.1 Pro Preview - Long-context reasoning",
|
||||
"recommended": false,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 872000
|
||||
"max_context_tokens": 872000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "qwen/qwen3.6-plus",
|
||||
"label": "Qwen 3.6 Plus - Strong reasoning",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "z-ai/glm-5v-turbo",
|
||||
"label": "GLM-5V Turbo - Vision capable",
|
||||
"recommended": true,
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 192000
|
||||
"max_context_tokens": 192000,
|
||||
"supports_vision": true
|
||||
},
|
||||
{
|
||||
"id": "z-ai/glm-5.1",
|
||||
"label": "GLM-5.1 - Better but Slower",
|
||||
"recommended": true,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 192000
|
||||
"max_context_tokens": 192000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 1.40,
|
||||
"output": 4.40,
|
||||
"cache_read": 0.26,
|
||||
"cache_creation": 0.0
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "minimax/minimax-m2.7",
|
||||
"label": "Minimax M2.7 - Minimax flagship",
|
||||
"recommended": false,
|
||||
"max_tokens": 40960,
|
||||
"max_context_tokens": 180000
|
||||
"max_context_tokens": 180000,
|
||||
"pricing_usd_per_mtok": {
|
||||
"input": 0.30,
|
||||
"output": 1.20
|
||||
},
|
||||
"supports_vision": false
|
||||
},
|
||||
{
|
||||
"id": "xiaomi/mimo-v2-pro",
|
||||
"label": "MiMo V2 Pro - Xiaomi multimodal",
|
||||
"recommended": true,
|
||||
"max_tokens": 64000,
|
||||
"max_context_tokens": 240000
|
||||
"max_context_tokens": 240000,
|
||||
"supports_vision": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -352,7 +447,7 @@
|
||||
"zai_code": {
|
||||
"provider": "openai",
|
||||
"api_key_env_var": "ZAI_API_KEY",
|
||||
"model": "glm-5",
|
||||
"model": "glm-5.1",
|
||||
"max_tokens": 32768,
|
||||
"max_context_tokens": 180000,
|
||||
"api_base": "https://api.z.ai/api/coding/paas/v4"
|
||||
@@ -399,8 +494,8 @@
|
||||
"recommended": false
|
||||
},
|
||||
{
|
||||
"id": "GLM-5",
|
||||
"label": "GLM-5",
|
||||
"id": "glm-5.1",
|
||||
"label": "glm-5.1",
|
||||
"recommended": false
|
||||
}
|
||||
]
|
||||
|
||||
@@ -27,6 +27,28 @@ def _require_list(value: Any, path: str) -> list[Any]:
|
||||
return value
|
||||
|
||||
|
||||
_PRICING_KEYS = ("input", "output", "cache_read", "cache_creation")
|
||||
|
||||
|
||||
def _validate_pricing(value: Any, path: str) -> None:
|
||||
"""Validate an optional ``pricing_usd_per_mtok`` block.
|
||||
|
||||
Keys are USD-per-million-tokens rates. ``input``/``output`` are required;
|
||||
``cache_read``/``cache_creation`` are optional. All values must be
|
||||
non-negative numbers. Used as a last-resort fallback when neither the
|
||||
provider nor LiteLLM's catalog reports a cost.
|
||||
"""
|
||||
pricing = _require_mapping(value, path)
|
||||
for key in ("input", "output"):
|
||||
if key not in pricing:
|
||||
raise ModelCatalogError(f"{path}.{key} is required")
|
||||
for key, rate in pricing.items():
|
||||
if key not in _PRICING_KEYS:
|
||||
raise ModelCatalogError(f"{path}.{key} is not a recognized pricing field")
|
||||
if not isinstance(rate, (int, float)) or isinstance(rate, bool) or rate < 0:
|
||||
raise ModelCatalogError(f"{path}.{key} must be a non-negative number")
|
||||
|
||||
|
||||
def _validate_model_catalog(data: dict[str, Any]) -> dict[str, Any]:
|
||||
providers = _require_mapping(data.get("providers"), "providers")
|
||||
|
||||
@@ -69,6 +91,14 @@ def _validate_model_catalog(data: dict[str, Any]) -> dict[str, Any]:
|
||||
if not isinstance(value, int) or value <= 0:
|
||||
raise ModelCatalogError(f"{model_path}.{key} must be a positive integer")
|
||||
|
||||
pricing = model_map.get("pricing_usd_per_mtok")
|
||||
if pricing is not None:
|
||||
_validate_pricing(pricing, f"{model_path}.pricing_usd_per_mtok")
|
||||
|
||||
supports_vision = model_map.get("supports_vision")
|
||||
if supports_vision is not None and not isinstance(supports_vision, bool):
|
||||
raise ModelCatalogError(f"{model_path}.supports_vision must be a boolean when present")
|
||||
|
||||
if not default_found:
|
||||
raise ModelCatalogError(
|
||||
f"{provider_path}.default_model={default_model!r} is not present in {provider_path}.models"
|
||||
@@ -184,6 +214,53 @@ def get_model_limits(provider: str, model_id: str) -> tuple[int, int] | None:
|
||||
return int(model["max_tokens"]), int(model["max_context_tokens"])
|
||||
|
||||
|
||||
def get_model_pricing(model_id: str) -> dict[str, float] | None:
|
||||
"""Return ``pricing_usd_per_mtok`` for a model id, searching all providers.
|
||||
|
||||
Returns ``None`` when the model is absent from the catalog or has no
|
||||
pricing entry. Used by the cost-extraction fallback in ``litellm.py``
|
||||
when the provider response and LiteLLM's catalog both come up empty.
|
||||
"""
|
||||
if not model_id:
|
||||
return None
|
||||
for provider_info in load_model_catalog()["providers"].values():
|
||||
for model in provider_info["models"]:
|
||||
if model["id"] == model_id:
|
||||
pricing = model.get("pricing_usd_per_mtok")
|
||||
if pricing is None:
|
||||
return None
|
||||
return {key: float(rate) for key, rate in pricing.items()}
|
||||
return None
|
||||
|
||||
|
||||
def model_supports_vision(model_id: str) -> bool:
|
||||
"""Return whether *model_id* supports image inputs per the curated catalog.
|
||||
|
||||
Looks up the bare model id (and the provider-prefix-stripped form) in the
|
||||
catalog. Returns the model's ``supports_vision`` flag when found, defaulting
|
||||
to ``True`` for unknown models or when the flag is absent — assume vision
|
||||
capable for hosted providers, since modern frontier models support images
|
||||
by default and the captioning fallback is more expensive than just letting
|
||||
the provider handle the image.
|
||||
"""
|
||||
if not model_id:
|
||||
return True
|
||||
|
||||
candidates = [model_id]
|
||||
if "/" in model_id:
|
||||
candidates.append(model_id.split("/", 1)[1])
|
||||
|
||||
for candidate in candidates:
|
||||
for provider_info in load_model_catalog()["providers"].values():
|
||||
for model in provider_info["models"]:
|
||||
if model["id"] == candidate:
|
||||
flag = model.get("supports_vision")
|
||||
if isinstance(flag, bool):
|
||||
return flag
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def get_preset(preset_id: str) -> dict[str, Any] | None:
|
||||
"""Return one preset entry."""
|
||||
preset = load_model_catalog()["presets"].get(preset_id)
|
||||
|
||||
@@ -10,12 +10,24 @@ from typing import Any
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM call."""
|
||||
"""Response from an LLM call.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` (providers report them inside ``prompt_tokens``).
|
||||
Surface them for visibility; do not add to a total.
|
||||
|
||||
``cost_usd`` is the per-call USD cost when the provider / pricing table
|
||||
can produce one (Anthropic, OpenAI, OpenRouter are supported). 0.0 when
|
||||
unknown or unpriced — treat as "unreported", not "free".
|
||||
"""
|
||||
|
||||
content: str
|
||||
model: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cost_usd: float = 0.0
|
||||
stop_reason: str = ""
|
||||
raw_response: Any = None
|
||||
|
||||
@@ -110,19 +122,28 @@ class LLMProvider(ABC):
|
||||
response_format: dict[str, Any] | None = None,
|
||||
json_mode: bool = False,
|
||||
max_retries: int | None = None,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> "LLMResponse":
|
||||
"""Async version of complete(). Non-blocking on the event loop.
|
||||
|
||||
Default implementation offloads the sync complete() to a thread pool.
|
||||
Subclasses SHOULD override for native async I/O.
|
||||
|
||||
``system_dynamic_suffix`` is an optional per-turn tail for providers
|
||||
that honor ``cache_control`` (see LiteLLMProvider for semantics).
|
||||
The default implementation concatenates it onto ``system`` since the
|
||||
sync ``complete()`` path does not support the split.
|
||||
"""
|
||||
combined_system = system
|
||||
if system_dynamic_suffix:
|
||||
combined_system = f"{system}\n\n{system_dynamic_suffix}" if system else system_dynamic_suffix
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self.complete,
|
||||
messages=messages,
|
||||
system=system,
|
||||
system=combined_system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
@@ -137,6 +158,7 @@ class LLMProvider(ABC):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
system_dynamic_suffix: str | None = None,
|
||||
) -> AsyncIterator["StreamEvent"]:
|
||||
"""
|
||||
Stream a completion as an async iterator of StreamEvents.
|
||||
@@ -147,6 +169,9 @@ class LLMProvider(ABC):
|
||||
Tool orchestration is the CALLER's responsibility:
|
||||
- Caller detects ToolCallEvent, executes tool, adds result
|
||||
to messages, calls stream() again.
|
||||
|
||||
``system_dynamic_suffix`` is forwarded to ``acomplete``; see its
|
||||
docstring for the two-block split semantics.
|
||||
"""
|
||||
from framework.llm.stream_events import (
|
||||
FinishEvent,
|
||||
@@ -159,6 +184,7 @@ class LLMProvider(ABC):
|
||||
system=system,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
system_dynamic_suffix=system_dynamic_suffix,
|
||||
)
|
||||
yield TextDeltaEvent(content=response.content, snapshot=response.content)
|
||||
yield TextEndEvent(full_text=response.content)
|
||||
@@ -166,6 +192,9 @@ class LLMProvider(ABC):
|
||||
stop_reason=response.stop_reason,
|
||||
input_tokens=response.input_tokens,
|
||||
output_tokens=response.output_tokens,
|
||||
cached_tokens=response.cached_tokens,
|
||||
cache_creation_tokens=response.cache_creation_tokens,
|
||||
cost_usd=response.cost_usd,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,13 +65,23 @@ class ReasoningDeltaEvent:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinishEvent:
|
||||
"""The LLM has finished generating."""
|
||||
"""The LLM has finished generating.
|
||||
|
||||
``cached_tokens`` and ``cache_creation_tokens`` are subsets of
|
||||
``input_tokens`` — providers count both inside ``prompt_tokens`` already.
|
||||
Surface them separately for visibility; never add to a total.
|
||||
|
||||
``cost_usd`` is the per-turn USD cost when the provider or LiteLLM's
|
||||
pricing table supplies one; 0.0 means unreported (not free).
|
||||
"""
|
||||
|
||||
type: Literal["finish"] = "finish"
|
||||
stop_reason: str = ""
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cost_usd: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ Nodes that need browser access declare ``tools: {policy: "all"}`` in their
|
||||
agent.json config.
|
||||
|
||||
Note: the canonical source of truth for browser automation guidance is
|
||||
the ``browser-automation`` default skill at
|
||||
``core/framework/skills/_default_skills/browser-automation/SKILL.md``.
|
||||
the ``browser-automation`` preset skill at
|
||||
``core/framework/skills/_preset_skills/browser-automation/SKILL.md``.
|
||||
Activate that skill for the full decision tree. This module holds a
|
||||
compact subset suitable for direct inlining into a node's system prompt
|
||||
when a skill activation is not desired.
|
||||
|
||||
@@ -543,6 +543,10 @@ class NodeContext:
|
||||
# Dynamic memory provider — when set, EventLoopNode rebuilds the
|
||||
# system prompt with the latest memory block each iteration.
|
||||
dynamic_memory_provider: Any = None # Callable[[], str] | None
|
||||
# Surgical skills-catalog refresh, same contract as AgentContext's
|
||||
# field of the same name. Lets workers pick up UI-driven skill
|
||||
# toggles without rebuilding the full system prompt each turn.
|
||||
dynamic_skills_catalog_provider: Any = None # Callable[[], str] | None
|
||||
|
||||
# Skill system prompts — injected by the skill discovery pipeline
|
||||
skills_catalog_prompt: str = "" # Available skills XML catalog
|
||||
|
||||
@@ -155,6 +155,17 @@ class SessionState(BaseModel):
|
||||
# True after first successful worker execution (gates trigger delivery on restart)
|
||||
worker_configured: bool = Field(default=False)
|
||||
|
||||
# Task-system fields (see framework/tasks).
|
||||
# task_list_id: this session's own task list id (populated on first
|
||||
# task_create; immutable thereafter). Used for resume reattachment —
|
||||
# if it differs from resolve_task_list_id(ctx) on resume, a
|
||||
# TASK_LIST_REATTACH_MISMATCH event is emitted and a fresh list is
|
||||
# created at the resolved id (the orphan stays on disk).
|
||||
task_list_id: str | None = None
|
||||
# picked_up_from: for worker sessions, the (colony_task_list_id,
|
||||
# template_task_id) pair this session was spawned for.
|
||||
picked_up_from: list[Any] | None = None
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@property
|
||||
|
||||
@@ -140,6 +140,25 @@ async def cors_middleware(request: web.Request, handler):
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def no_cache_api_middleware(request: web.Request, handler):
|
||||
"""Prevent browsers from caching API responses.
|
||||
|
||||
Without this, a one-off bad response (e.g. the SPA catch-all leaking
|
||||
index.html for an /api/* URL before a route was registered) can get
|
||||
pinned in the browser's disk cache and replayed forever, since our
|
||||
JSON handlers don't emit ETag/Last-Modified and browsers fall back
|
||||
to heuristic freshness.
|
||||
"""
|
||||
try:
|
||||
response = await handler(request)
|
||||
except web.HTTPException as exc:
|
||||
response = exc
|
||||
if request.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def error_middleware(request: web.Request, handler):
|
||||
"""Catch exceptions and return JSON error responses.
|
||||
@@ -268,7 +287,7 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
Returns:
|
||||
Configured aiohttp Application ready to run.
|
||||
"""
|
||||
app = web.Application(middlewares=[cors_middleware, error_middleware])
|
||||
app = web.Application(middlewares=[cors_middleware, no_cache_api_middleware, error_middleware])
|
||||
|
||||
# Initialize credential store (before SessionManager so it can be shared)
|
||||
from framework.credentials.store import CredentialStore
|
||||
@@ -325,16 +344,21 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
app.router.add_get("/api/browser/status/stream", handle_browser_status_stream)
|
||||
|
||||
# Register route modules
|
||||
from framework.server.routes_colony_tools import register_routes as register_colony_tools_routes
|
||||
from framework.server.routes_colony_workers import register_routes as register_colony_worker_routes
|
||||
from framework.server.routes_config import register_routes as register_config_routes
|
||||
from framework.server.routes_credentials import register_routes as register_credential_routes
|
||||
from framework.server.routes_events import register_routes as register_event_routes
|
||||
from framework.server.routes_execution import register_routes as register_execution_routes
|
||||
from framework.server.routes_logs import register_routes as register_log_routes
|
||||
from framework.server.routes_mcp import register_routes as register_mcp_routes
|
||||
from framework.server.routes_messages import register_routes as register_message_routes
|
||||
from framework.server.routes_prompts import register_routes as register_prompt_routes
|
||||
from framework.server.routes_queen_tools import register_routes as register_queen_tools_routes
|
||||
from framework.server.routes_queens import register_routes as register_queen_routes
|
||||
from framework.server.routes_sessions import register_routes as register_session_routes
|
||||
from framework.server.routes_skills import register_routes as register_skills_routes
|
||||
from framework.server.routes_tasks import register_routes as register_task_routes
|
||||
from framework.server.routes_workers import register_routes as register_worker_routes
|
||||
|
||||
register_config_routes(app)
|
||||
@@ -346,8 +370,13 @@ def create_app(model: str | None = None) -> web.Application:
|
||||
register_worker_routes(app)
|
||||
register_log_routes(app)
|
||||
register_queen_routes(app)
|
||||
register_queen_tools_routes(app)
|
||||
register_colony_tools_routes(app)
|
||||
register_mcp_routes(app)
|
||||
register_colony_worker_routes(app)
|
||||
register_prompt_routes(app)
|
||||
register_skills_routes(app)
|
||||
register_task_routes(app)
|
||||
|
||||
# Static file serving — Option C production mode
|
||||
# If frontend/dist/ exists, serve built frontend files on /
|
||||
|
||||
@@ -253,6 +253,92 @@ async def materialize_queen_identity(
|
||||
)
|
||||
|
||||
|
||||
def build_queen_tool_registry_bare() -> tuple[Any, dict[str, list[dict[str, Any]]]]:
|
||||
"""Build a Queen ``ToolRegistry`` and a (server_name → tools) catalog.
|
||||
|
||||
Used by the Tool Library GET route to populate the MCP tool surface
|
||||
without needing a live queen session. We DO NOT register queen
|
||||
lifecycle tools here (they require a Session stub); the catalog only
|
||||
covers MCP-origin tools, which is what the allowlist gates.
|
||||
|
||||
Loading MCP servers spawns subprocesses, so call this once per
|
||||
backend process and cache the result.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import framework.agents.queen as _queen_pkg
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
|
||||
queen_registry = ToolRegistry()
|
||||
queen_pkg_dir = Path(_queen_pkg.__file__).parent
|
||||
|
||||
mcp_config = queen_pkg_dir / "mcp_servers.json"
|
||||
if mcp_config.exists():
|
||||
try:
|
||||
queen_registry.load_mcp_config(mcp_config)
|
||||
except Exception:
|
||||
logger.warning("build_queen_tool_registry_bare: MCP config failed", exc_info=True)
|
||||
|
||||
try:
|
||||
reg = MCPRegistry()
|
||||
reg.initialize()
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = reg.load_agent_selection(queen_pkg_dir)
|
||||
|
||||
already = {cfg.get("name") for cfg in registry_configs if cfg.get("name")}
|
||||
extra: list[str] = []
|
||||
try:
|
||||
for entry in reg.list_installed():
|
||||
if entry.get("source") != "local":
|
||||
continue
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
if name and name not in already:
|
||||
extra.append(name)
|
||||
except Exception:
|
||||
pass
|
||||
if extra:
|
||||
try:
|
||||
extra_configs = reg.resolve_for_agent(include=extra)
|
||||
registry_configs = list(registry_configs) + [reg._server_config_to_dict(c) for c in extra_configs]
|
||||
except Exception:
|
||||
logger.debug("build_queen_tool_registry_bare: resolve_for_agent(extra) failed", exc_info=True)
|
||||
|
||||
if registry_configs:
|
||||
queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
preserve_existing_tools=True,
|
||||
log_collisions=False,
|
||||
max_tools=selection_max_tools,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("build_queen_tool_registry_bare: MCP registry load failed", exc_info=True)
|
||||
|
||||
# Build the catalog.
|
||||
tools_by_name = queen_registry.get_tools()
|
||||
server_map = dict(getattr(queen_registry, "_mcp_server_tools", {}) or {})
|
||||
catalog: dict[str, list[dict[str, Any]]] = {}
|
||||
for server_name in sorted(server_map):
|
||||
entries: list[dict[str, Any]] = []
|
||||
for tool_name in sorted(server_map[server_name]):
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
entries.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.parameters,
|
||||
}
|
||||
)
|
||||
catalog[server_name] = entries
|
||||
|
||||
return queen_registry, catalog
|
||||
|
||||
|
||||
async def create_queen(
|
||||
session: Session,
|
||||
session_manager: Any,
|
||||
@@ -326,6 +412,45 @@ async def create_queen(
|
||||
if (queen_pkg_dir / "mcp_registry.json").is_file():
|
||||
queen_registry.set_mcp_registry_agent_path(queen_pkg_dir)
|
||||
registry_configs, selection_max_tools = registry.load_agent_selection(queen_pkg_dir)
|
||||
|
||||
# Auto-include every user-added local MCP server that the repo
|
||||
# selection hasn't already loaded. Users register servers via
|
||||
# the `/api/mcp/servers` route (or `hive mcp add`); they live in
|
||||
# ~/.hive/mcp_registry/installed.json with source == "local".
|
||||
# New servers take effect on the next queen session start; the
|
||||
# prompt cache and ToolRegistry are still loaded once per boot.
|
||||
already_loaded_names = {cfg.get("name") for cfg in registry_configs if cfg.get("name")}
|
||||
extra_names: list[str] = []
|
||||
try:
|
||||
for entry in registry.list_installed():
|
||||
if entry.get("source") != "local":
|
||||
continue
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
if not name or name in already_loaded_names:
|
||||
continue
|
||||
extra_names.append(name)
|
||||
except Exception:
|
||||
logger.debug("Queen: list_installed() failed while auto-including user servers", exc_info=True)
|
||||
|
||||
if extra_names:
|
||||
try:
|
||||
extra_configs = registry.resolve_for_agent(include=extra_names)
|
||||
extra_dicts = [registry._server_config_to_dict(c) for c in extra_configs]
|
||||
registry_configs = list(registry_configs) + extra_dicts
|
||||
logger.info(
|
||||
"Queen: auto-including %d user-added MCP server(s): %s",
|
||||
len(extra_dicts),
|
||||
[c.get("name") for c in extra_dicts],
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Queen: failed to resolve user-added MCP servers %s",
|
||||
extra_names,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if registry_configs:
|
||||
results = queen_registry.load_registry_servers(
|
||||
registry_configs,
|
||||
@@ -363,6 +488,21 @@ async def create_queen(
|
||||
phase_state=phase_state,
|
||||
)
|
||||
|
||||
# ---- Task system tools --------------------------------------------
|
||||
# Every queen gets the four session task tools. Queens-of-colony
|
||||
# additionally get the colony_template_* tools (gated by colony_id).
|
||||
from framework.tasks.tools import (
|
||||
register_colony_template_tools,
|
||||
register_task_tools,
|
||||
)
|
||||
|
||||
register_task_tools(queen_registry)
|
||||
_colony_id_for_queen = getattr(session, "colony_id", None) or getattr(
|
||||
getattr(session, "colony_runtime", None), "_colony_id", None
|
||||
)
|
||||
if _colony_id_for_queen:
|
||||
register_colony_template_tools(queen_registry, colony_id=_colony_id_for_queen)
|
||||
|
||||
# ---- Colony runtime check (only when worker is loaded) ----------------
|
||||
if session.colony_runtime:
|
||||
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
|
||||
@@ -417,6 +557,71 @@ async def create_queen(
|
||||
sorted(t.name for t in phase_state.incubating_tools),
|
||||
)
|
||||
|
||||
# ---- Per-queen MCP tool allowlist --------------------------------
|
||||
# Capture the set of MCP-origin tool names so the allowlist in
|
||||
# ``QueenPhaseState`` only gates MCP tools (lifecycle and synthetic
|
||||
# tools always pass through). Then apply the queen profile's stored
|
||||
# allowlist (if any) and memoize the filtered independent tool list.
|
||||
mcp_server_tools_map: dict[str, set[str]] = dict(getattr(queen_registry, "_mcp_server_tools", {}))
|
||||
phase_state.mcp_tool_names_all = set().union(*mcp_server_tools_map.values()) if mcp_server_tools_map else set()
|
||||
# The queen's MCP tool allowlist now lives in a dedicated
|
||||
# ``tools.json`` sidecar next to ``profile.yaml``. ``load_queen_tools_config``
|
||||
# migrates any legacy ``enabled_mcp_tools`` field out of profile.yaml
|
||||
# on first read, so existing installs upgrade silently.
|
||||
from framework.agents.queen.queen_tools_config import load_queen_tools_config
|
||||
|
||||
# Build a minimal catalog for default-tool resolution. The full
|
||||
# ``session_manager._mcp_tool_catalog`` snapshot is written further
|
||||
# down the flow; a queen booted for the first time needs the catalog
|
||||
# now so ``@server:NAME`` shorthands in the role-default table can
|
||||
# expand against the just-loaded MCP servers.
|
||||
_boot_catalog: dict[str, list[dict]] = {
|
||||
srv: [{"name": name} for name in sorted(names)] for srv, names in mcp_server_tools_map.items()
|
||||
}
|
||||
# ``queen_dir`` is ``queens/<queen_id>/sessions/<session_id>``; the
|
||||
# allowlist sidecar is keyed by queen_id, not session_id.
|
||||
phase_state.enabled_mcp_tools = load_queen_tools_config(session.queen_name, _boot_catalog)
|
||||
phase_state.rebuild_independent_filter()
|
||||
if phase_state.enabled_mcp_tools is not None:
|
||||
total_mcp = len(phase_state.mcp_tool_names_all)
|
||||
allowed_mcp = len(set(phase_state.enabled_mcp_tools) & phase_state.mcp_tool_names_all)
|
||||
logger.info(
|
||||
"Queen: per-queen MCP allowlist active — %d of %d MCP tools enabled",
|
||||
allowed_mcp,
|
||||
total_mcp,
|
||||
)
|
||||
|
||||
# ---- MCP tool catalog for the frontend ---------------------------
|
||||
# Snapshot per-server tool metadata so the Queen Tools API can render
|
||||
# the tool surface without spawning MCP subprocesses. Keyed by server
|
||||
# name so the UI can group tools by origin. Updated every time a
|
||||
# queen boots, so installing a new server and starting a new queen
|
||||
# session refreshes the catalog.
|
||||
mcp_tool_catalog: dict[str, list[dict[str, Any]]] = {}
|
||||
tools_by_name = {t.name: t for t in queen_tools}
|
||||
for server_name, tool_names in mcp_server_tools_map.items():
|
||||
server_entries: list[dict[str, Any]] = []
|
||||
for tool_name in sorted(tool_names):
|
||||
tool = tools_by_name.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
server_entries.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.parameters,
|
||||
}
|
||||
)
|
||||
mcp_tool_catalog[server_name] = server_entries
|
||||
# All queens share one MCP registry, so the catalog is a manager-level
|
||||
# fact; stash it on the SessionManager so the Queen Tools route can
|
||||
# render the tool list even when no queen session is currently live.
|
||||
if session_manager is not None:
|
||||
try:
|
||||
session_manager._mcp_tool_catalog = mcp_tool_catalog # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("Queen: could not attach mcp_tool_catalog to manager", exc_info=True)
|
||||
|
||||
# ---- Global + queen-scoped memory ----------------------------------
|
||||
global_dir, queen_mem_dir = initialize_memory_scopes(session, phase_state)
|
||||
|
||||
@@ -476,12 +681,34 @@ async def create_queen(
|
||||
# ---- Default skill protocols -------------------------------------
|
||||
_queen_skill_dirs: list[str] = []
|
||||
try:
|
||||
from framework.config import QUEENS_DIR
|
||||
from framework.skills.discovery import ExtraScope
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
|
||||
# Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/)
|
||||
# are discovered. Queen has no agent-specific project root, so we use its
|
||||
# own directory — the value just needs to be non-None to enable user-scope scanning.
|
||||
_queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent))
|
||||
# Queen home backs the queen-UI skill scope and the queen's
|
||||
# override store. The directory already exists (or is created on
|
||||
# demand by queen_profiles.py); treat a missing queen_name as the
|
||||
# default queen to preserve backwards compatibility.
|
||||
_queen_id = getattr(session, "queen_name", None) or "default"
|
||||
_queen_home = QUEENS_DIR / _queen_id
|
||||
_queen_skills_mgr = SkillsManager(
|
||||
SkillsManagerConfig(
|
||||
queen_id=_queen_id,
|
||||
queen_overrides_path=_queen_home / "skills_overrides.json",
|
||||
extra_scope_dirs=[
|
||||
ExtraScope(
|
||||
directory=_queen_home / "skills",
|
||||
label="queen_ui",
|
||||
priority=2,
|
||||
)
|
||||
],
|
||||
# No project_root — queen's project is her own identity;
|
||||
# user-scope discovery still runs without one.
|
||||
project_root=None,
|
||||
skip_community_discovery=True,
|
||||
interactive=False,
|
||||
)
|
||||
)
|
||||
_queen_skills_mgr.load()
|
||||
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
|
||||
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
|
||||
@@ -520,8 +747,37 @@ async def create_queen(
|
||||
|
||||
# ---- Recall on each real user turn --------------------------------
|
||||
async def _recall_on_user_input(event: AgentEvent) -> None:
|
||||
"""Re-select memories when real user input arrives."""
|
||||
await _refresh_recall_cache((event.data or {}).get("content", ""))
|
||||
"""On real user input, freeze the dynamic system-prompt suffix and
|
||||
refresh recall memories in the background.
|
||||
|
||||
The EventBus drops handlers that exceed 15s, so we MUST return fast.
|
||||
Recall selection queries the LLM and can take >15s on slow backends;
|
||||
we fire it off as a background task and re-stamp the suffix when it
|
||||
completes. The immediate refresh_dynamic_suffix call stamps a fresh
|
||||
timestamp using the last-known recall blocks so every iteration of
|
||||
THIS user turn sees a byte-stable prompt (prompt cache hits on the
|
||||
static block). Phase-change injections and worker-report injections
|
||||
go through agent_loop.inject_event() and do NOT publish
|
||||
CLIENT_INPUT_RECEIVED, so this runs exactly once per real user turn.
|
||||
"""
|
||||
query = (event.data or {}).get("content", "")
|
||||
# Immediate: stamp "now" into the frozen suffix, using whatever
|
||||
# recall blocks we already cached (from the prior turn or seeding).
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
|
||||
async def _bg_refresh() -> None:
|
||||
try:
|
||||
await _refresh_recall_cache(query)
|
||||
# Re-stamp with the fresh recall blocks. Any iteration that
|
||||
# read the suffix before this point used the older recall
|
||||
# — acceptable; recall was already eventual-consistency.
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
except Exception:
|
||||
logger.debug("background recall refresh failed", exc_info=True)
|
||||
|
||||
import asyncio as _asyncio
|
||||
|
||||
_asyncio.create_task(_bg_refresh())
|
||||
|
||||
session.event_bus.subscribe(
|
||||
[EventType.CLIENT_INPUT_RECEIVED],
|
||||
@@ -631,6 +887,9 @@ async def create_queen(
|
||||
except Exception:
|
||||
logger.debug("recall: initial seeding failed", exc_info=True)
|
||||
|
||||
# Freeze the dynamic suffix once so the first real turn sends a
|
||||
# byte-stable prompt even before CLIENT_INPUT_RECEIVED fires.
|
||||
phase_state.refresh_dynamic_suffix()
|
||||
return HookResult(system_prompt=phase_state.get_current_prompt())
|
||||
|
||||
# ---- Colony preparation -------------------------------------------
|
||||
@@ -675,10 +934,21 @@ async def create_queen(
|
||||
# token stays local to this task.
|
||||
try:
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
from framework.tasks.scoping import session_task_list_id
|
||||
|
||||
ToolRegistry.set_execution_context(profile=session.id)
|
||||
queen_agent_id = getattr(session, "agent_id", None) or "queen"
|
||||
queen_list_id = session_task_list_id(queen_agent_id, session.id)
|
||||
colony_id = getattr(session, "colony_id", None) or getattr(
|
||||
getattr(session, "colony_runtime", None), "_colony_id", None
|
||||
)
|
||||
ToolRegistry.set_execution_context(
|
||||
profile=session.id,
|
||||
agent_id=queen_agent_id,
|
||||
task_list_id=queen_list_id,
|
||||
colony_id=colony_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Queen: failed to set browser profile for session %s", session.id, exc_info=True)
|
||||
logger.debug("Queen: failed to set execution context for session %s", session.id, exc_info=True)
|
||||
try:
|
||||
lc = _queen_loop_config
|
||||
queen_loop_config = LoopConfig(
|
||||
@@ -730,7 +1000,8 @@ async def create_queen(
|
||||
stream_id="queen",
|
||||
execution_id=session.id,
|
||||
dynamic_tools_provider=phase_state.get_current_tools,
|
||||
dynamic_prompt_provider=phase_state.get_current_prompt,
|
||||
dynamic_prompt_provider=phase_state.get_static_prompt,
|
||||
dynamic_prompt_suffix_provider=phase_state.get_dynamic_suffix,
|
||||
iteration_metadata_provider=lambda: {"phase": phase_state.phase},
|
||||
skills_catalog_prompt=phase_state.skills_catalog_prompt,
|
||||
protocols_prompt=phase_state.protocols_prompt,
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
"""Per-colony MCP tool allowlist routes.
|
||||
|
||||
- GET /api/colony/{colony_name}/tools -- enumerate colony tool surface
|
||||
- PATCH /api/colony/{colony_name}/tools -- set or clear the allowlist
|
||||
|
||||
A colony's tool set is inherited from the queen that forked it, so the
|
||||
tool surface mirrors the queen's MCP servers. Lifecycle/synthetic tools
|
||||
are included for display only. MCP tools are grouped by origin server
|
||||
with per-tool ``enabled`` flags.
|
||||
|
||||
Semantics:
|
||||
|
||||
- ``enabled_mcp_tools: null`` → allow every MCP tool (default).
|
||||
- ``enabled_mcp_tools: []`` → allow no MCP tools (only lifecycle /
|
||||
synthetic pass through).
|
||||
- ``enabled_mcp_tools: [...]`` → only listed names pass.
|
||||
|
||||
The allowlist is persisted in a dedicated ``tools.json`` sidecar at
|
||||
``~/.hive/colonies/{colony_name}/tools.json``. Changes take effect on
|
||||
the *next* worker spawn. In-flight workers keep the tool list they
|
||||
booted with because workers have no dynamic tools provider today —
|
||||
mutating their tool set mid-turn would diverge from the list the LLM
|
||||
is already using.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.host.colony_metadata import colony_metadata_path
|
||||
from framework.host.colony_tools_config import (
|
||||
load_colony_tools_config,
|
||||
update_colony_tools_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SYNTHETIC_NAMES = {"ask_user"}
|
||||
|
||||
|
||||
def _synthetic_entries() -> list[dict[str, Any]]:
|
||||
try:
|
||||
from framework.agent_loop.internals.synthetic_tools import build_ask_user_tool
|
||||
|
||||
tool = build_ask_user_tool()
|
||||
return [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
except Exception:
|
||||
return [
|
||||
{
|
||||
"name": "ask_user",
|
||||
"description": "Pause and ask the user a structured question.",
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _colony_runtimes_for_name(manager: Any, colony_name: str) -> list[Any]:
|
||||
"""Return every live ColonyRuntime whose session is attached to ``colony_name``."""
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
runtimes: list[Any] = []
|
||||
for session in sessions.values():
|
||||
if getattr(session, "colony_name", None) != colony_name:
|
||||
continue
|
||||
# Both ``session.colony`` (queen-side unified runtime) and
|
||||
# ``session.colony_runtime`` (legacy worker runtime) may carry
|
||||
# tools that need the allowlist applied. We update both.
|
||||
for attr in ("colony", "colony_runtime"):
|
||||
rt = getattr(session, attr, None)
|
||||
if rt is not None and rt not in runtimes:
|
||||
runtimes.append(rt)
|
||||
return runtimes
|
||||
|
||||
|
||||
async def _render_catalog(manager: Any, colony_name: str) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Build a per-server tool catalog for this colony.
|
||||
|
||||
All colonies inherit the queen's MCP surface, so we reuse the
|
||||
manager-level ``_mcp_tool_catalog`` populated during queen boot.
|
||||
"""
|
||||
# If a live runtime exists and carries its own registry, prefer it —
|
||||
# it's authoritative (reflects any post-queen-boot MCP additions).
|
||||
for rt in _colony_runtimes_for_name(manager, colony_name):
|
||||
tools = getattr(rt, "_tools", None)
|
||||
if not tools:
|
||||
continue
|
||||
mcp_names = set(getattr(rt, "_mcp_tool_names_all", set()) or set())
|
||||
if not mcp_names:
|
||||
continue
|
||||
catalog: dict[str, list[dict[str, Any]]] = {"(mcp)": []}
|
||||
for tool in tools:
|
||||
name = getattr(tool, "name", None)
|
||||
if name in mcp_names:
|
||||
catalog["(mcp)"].append(
|
||||
{
|
||||
"name": name,
|
||||
"description": getattr(tool, "description", ""),
|
||||
"input_schema": getattr(tool, "parameters", {}),
|
||||
}
|
||||
)
|
||||
return catalog
|
||||
|
||||
# Otherwise fall back to the queen-level snapshot. Build it on demand
|
||||
# (off the event loop) when empty so the Tool Library works before
|
||||
# any queen has been started in this process.
|
||||
cached = getattr(manager, "_mcp_tool_catalog", None)
|
||||
if isinstance(cached, dict) and cached:
|
||||
return cached
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
from framework.server.queen_orchestrator import build_queen_tool_registry_bare
|
||||
|
||||
registry, built = await asyncio.to_thread(build_queen_tool_registry_bare)
|
||||
if manager is not None:
|
||||
manager._mcp_tool_catalog = built # type: ignore[attr-defined]
|
||||
manager._bootstrap_tool_registry = registry # type: ignore[attr-defined]
|
||||
return built
|
||||
except Exception:
|
||||
logger.warning("Colony tools: catalog bootstrap failed", exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _lifecycle_entries_from_runtime(manager: Any, colony_name: str) -> list[dict[str, Any]]:
|
||||
"""Non-MCP tools currently registered on the colony runtime (if any).
|
||||
|
||||
When no live runtime is available we fall back to the bootstrap
|
||||
registry stashed on the manager by ``routes_queen_tools`` — it
|
||||
already has queen lifecycle tools registered, which are also the
|
||||
lifecycle tools colonies inherit at spawn time.
|
||||
"""
|
||||
out: list[dict[str, Any]] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def _push(name: str, description: str) -> None:
|
||||
if not name or name in seen:
|
||||
return
|
||||
if name in _SYNTHETIC_NAMES:
|
||||
return
|
||||
seen.add(name)
|
||||
out.append({"name": name, "description": description, "editable": False})
|
||||
|
||||
runtimes = _colony_runtimes_for_name(manager, colony_name)
|
||||
if runtimes:
|
||||
for rt in runtimes:
|
||||
mcp_names = set(getattr(rt, "_mcp_tool_names_all", set()) or set())
|
||||
for tool in getattr(rt, "_tools", []) or []:
|
||||
name = getattr(tool, "name", None)
|
||||
if name in mcp_names:
|
||||
continue
|
||||
_push(name, getattr(tool, "description", ""))
|
||||
else:
|
||||
# No live runtime — derive from the bootstrap registry.
|
||||
from framework.server.routes_queen_tools import _lifecycle_entries_without_session
|
||||
|
||||
catalog = getattr(manager, "_mcp_tool_catalog", {}) or {}
|
||||
mcp_names: set[str] = set()
|
||||
for entries in catalog.values():
|
||||
for entry in entries:
|
||||
if entry.get("name"):
|
||||
mcp_names.add(entry["name"])
|
||||
out.extend(_lifecycle_entries_without_session(manager, mcp_names))
|
||||
return out
|
||||
return sorted(out, key=lambda e: e["name"])
|
||||
|
||||
|
||||
def _render_servers(
|
||||
catalog: dict[str, list[dict[str, Any]]],
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
allowed: set[str] | None = None if enabled_mcp_tools is None else set(enabled_mcp_tools)
|
||||
servers: list[dict[str, Any]] = []
|
||||
for name in sorted(catalog):
|
||||
tools = []
|
||||
for entry in catalog[name]:
|
||||
tool_name = entry.get("name")
|
||||
tools.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"description": entry.get("description", ""),
|
||||
"input_schema": entry.get("input_schema", {}),
|
||||
"enabled": True if allowed is None else tool_name in allowed,
|
||||
}
|
||||
)
|
||||
servers.append({"name": name, "tools": tools})
|
||||
return servers
|
||||
|
||||
|
||||
async def handle_get_tools(request: web.Request) -> web.Response:
|
||||
"""GET /api/colony/{colony_name}/tools."""
|
||||
colony_name = request.match_info["colony_name"]
|
||||
if not colony_metadata_path(colony_name).exists():
|
||||
return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)
|
||||
|
||||
manager = request.app.get("manager")
|
||||
# Allowlist now lives in a dedicated tools.json sidecar; helper
|
||||
# migrates any legacy metadata.json field on first read.
|
||||
enabled = load_colony_tools_config(colony_name)
|
||||
|
||||
catalog = await _render_catalog(manager, colony_name)
|
||||
stale = not catalog
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"colony_name": colony_name,
|
||||
"enabled_mcp_tools": enabled,
|
||||
"stale": stale,
|
||||
"lifecycle": _lifecycle_entries_from_runtime(manager, colony_name),
|
||||
"synthetic": _synthetic_entries(),
|
||||
"mcp_servers": _render_servers(catalog, enabled),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_patch_tools(request: web.Request) -> web.Response:
|
||||
"""PATCH /api/colony/{colony_name}/tools."""
|
||||
colony_name = request.match_info["colony_name"]
|
||||
if not colony_metadata_path(colony_name).exists():
|
||||
return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict) or "enabled_mcp_tools" not in body:
|
||||
return web.json_response(
|
||||
{"error": "Body must be an object with an 'enabled_mcp_tools' field"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
enabled = body["enabled_mcp_tools"]
|
||||
if enabled is not None:
|
||||
if not isinstance(enabled, list) or not all(isinstance(x, str) for x in enabled):
|
||||
return web.json_response(
|
||||
{"error": "'enabled_mcp_tools' must be null or a list of strings"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
manager = request.app.get("manager")
|
||||
|
||||
# Validate names against the known MCP catalog — lifts the same
|
||||
# typo-catching guarantee we already offer on queen tools.
|
||||
catalog = await _render_catalog(manager, colony_name)
|
||||
known: set[str] = {e.get("name") for entries in catalog.values() for e in entries if e.get("name")}
|
||||
if enabled is not None and known:
|
||||
unknown = sorted(set(enabled) - known)
|
||||
if unknown:
|
||||
return web.json_response(
|
||||
{"error": "Unknown MCP tool name(s)", "unknown": unknown},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Persist — tools.json sidecar, not metadata.json. Missing directory
|
||||
# is already guarded by the 404 check above.
|
||||
try:
|
||||
update_colony_tools_config(colony_name, enabled)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Colony '{colony_name}' not found"}, status=404)
|
||||
|
||||
# Update any live runtimes so the NEXT worker spawn reflects the change.
|
||||
# We do NOT rebuild in-flight workers' tool lists (see module docstring).
|
||||
refreshed = 0
|
||||
for rt in _colony_runtimes_for_name(manager, colony_name):
|
||||
setter = getattr(rt, "set_tool_allowlist", None)
|
||||
if callable(setter):
|
||||
try:
|
||||
setter(enabled)
|
||||
refreshed += 1
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Colony tools: set_tool_allowlist failed on runtime for %s",
|
||||
colony_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Colony tools: colony=%s allowlist=%s refreshed_runtimes=%d",
|
||||
colony_name,
|
||||
"null" if enabled is None else f"{len(enabled)} tool(s)",
|
||||
refreshed,
|
||||
)
|
||||
return web.json_response(
|
||||
{
|
||||
"colony_name": colony_name,
|
||||
"enabled_mcp_tools": enabled,
|
||||
"refreshed_runtimes": refreshed,
|
||||
"note": "Changes apply to the next worker spawn. Running workers keep their booted tool list.",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_list_colonies(request: web.Request) -> web.Response:
|
||||
"""GET /api/colonies — list colonies with their tool allowlist status.
|
||||
|
||||
Powers the Tool Library page's colony picker.
|
||||
"""
|
||||
from framework.host.colony_metadata import list_colony_names, load_colony_metadata
|
||||
|
||||
colonies: list[dict[str, Any]] = []
|
||||
for name in list_colony_names():
|
||||
meta = load_colony_metadata(name)
|
||||
# Provenance stays in metadata.json; allowlist lives in tools.json.
|
||||
allowlist = load_colony_tools_config(name)
|
||||
colonies.append(
|
||||
{
|
||||
"name": name,
|
||||
"queen_name": meta.get("queen_name"),
|
||||
"created_at": meta.get("created_at"),
|
||||
"has_allowlist": allowlist is not None,
|
||||
"enabled_count": len(allowlist) if isinstance(allowlist, list) else None,
|
||||
}
|
||||
)
|
||||
return web.json_response({"colonies": colonies})
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
"""Register per-colony tool routes."""
|
||||
app.router.add_get("/api/colonies/tools-index", handle_list_colonies)
|
||||
app.router.add_get("/api/colony/{colony_name}/tools", handle_get_tools)
|
||||
app.router.add_patch("/api/colony/{colony_name}/tools", handle_patch_tools)
|
||||
@@ -1577,6 +1577,39 @@ async def fork_session_into_colony(
|
||||
}
|
||||
metadata_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
|
||||
# ── 4a. Inherit the queen's tool allowlist into the colony ───
|
||||
# A colony forked from a curated queen should start with the same
|
||||
# tool surface (otherwise the colony silently falls back to its own
|
||||
# "allow every MCP tool" default, undoing the parent's curation).
|
||||
# We copy the queen's LIVE effective allowlist so the snapshot
|
||||
# reflects whatever was in force the moment the user clicked "Create
|
||||
# Colony". Users can further narrow the colony via the Tool Library.
|
||||
# Skip the write when the queen is on allow-all (None) so the colony
|
||||
# keeps the same semantics without creating an inert sidecar.
|
||||
try:
|
||||
queen_enabled = getattr(
|
||||
getattr(session, "phase_state", None),
|
||||
"enabled_mcp_tools",
|
||||
None,
|
||||
)
|
||||
if isinstance(queen_enabled, list):
|
||||
from framework.host.colony_tools_config import update_colony_tools_config
|
||||
|
||||
update_colony_tools_config(colony_name, list(queen_enabled))
|
||||
logger.info(
|
||||
"Inherited queen allowlist into colony '%s' (%d tools)",
|
||||
colony_name,
|
||||
len(queen_enabled),
|
||||
)
|
||||
except Exception:
|
||||
# Inheritance is best-effort — don't let a tools.json hiccup
|
||||
# abort colony creation.
|
||||
logger.warning(
|
||||
"Failed to inherit queen allowlist into colony '%s'",
|
||||
colony_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ── 5. Update source queen session meta.json ─────────────────
|
||||
# Link the originating session back to the colony for discovery.
|
||||
source_meta_path = source_queen_dir / "meta.json"
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
"""MCP server registration routes.
|
||||
|
||||
Thin HTTP wrapper around ``MCPRegistry`` so the frontend can add, remove,
|
||||
enable, and health-check user-registered MCP servers. The CLI path
|
||||
(``hive mcp add`` / ``hive mcp remove`` / etc.) is unchanged.
|
||||
|
||||
- GET /api/mcp/servers -- list installed servers
|
||||
- POST /api/mcp/servers -- register a local server
|
||||
- DELETE /api/mcp/servers/{name} -- remove a local server
|
||||
- POST /api/mcp/servers/{name}/enable -- enable a server
|
||||
- POST /api/mcp/servers/{name}/disable -- disable a server
|
||||
- POST /api/mcp/servers/{name}/health -- probe server health
|
||||
|
||||
New servers take effect on the *next* queen session start. Existing live
|
||||
queen sessions keep the tool list they booted with to avoid mid-turn
|
||||
cache invalidation. The ``add`` response hints at this explicitly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.loader.mcp_errors import MCPError
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_VALID_TRANSPORTS = {"stdio", "http", "sse", "unix"}
|
||||
|
||||
|
||||
def _registry() -> MCPRegistry:
    """Build an initialized registry view.

    MCPRegistry only wraps ~/.hive/mcp_registry/installed.json, so
    constructing one per request is cheap — no app-level caching needed.
    """
    registry = MCPRegistry()
    registry.initialize()
    return registry
|
||||
|
||||
|
||||
def _package_builtin_servers() -> list[dict[str, Any]]:
    """Return the package-baked queen MCP servers from ``queen/mcp_servers.json``.

    Those servers are loaded directly by ``ToolRegistry.load_mcp_config``
    at queen boot and never go through ``MCPRegistry.list_installed``,
    so the raw registry view shows them as missing. Surface them here so
    the Tool Library reflects what the queen actually talks to.

    Entries carry ``source: "built-in"`` and are NOT removable / toggleable
    — editing them requires changing the repo file.
    """
    import json
    from pathlib import Path

    import framework.agents.queen as _queen_pkg

    config_path = Path(_queen_pkg.__file__).parent / "mcp_servers.json"
    if not config_path.exists():
        return []
    try:
        raw = json.loads(config_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return []

    servers: list[dict[str, Any]] = []
    for server_name, server_cfg in raw.items():
        if not isinstance(server_cfg, dict):
            continue
        servers.append(
            {
                "name": server_name,
                "source": "built-in",
                "transport": server_cfg.get("transport", "stdio"),
                "description": server_cfg.get("description", "") or "",
                "enabled": True,
                "last_health_status": None,
                "last_error": None,
                "last_health_check_at": None,
                "tool_count": None,
                "removable": False,
            }
        )
    return servers
|
||||
|
||||
|
||||
def _server_to_summary(entry: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Shape an installed.json entry for API responses.
|
||||
|
||||
Strips the full manifest body (which can be large) but keeps the tool
|
||||
list if the manifest already embeds one (happens for registry-installed
|
||||
servers). Users with ``source: "local"`` only get a tool list after
|
||||
running a health check.
|
||||
"""
|
||||
manifest = entry.get("manifest") or {}
|
||||
tools = manifest.get("tools") if isinstance(manifest, dict) else None
|
||||
if not isinstance(tools, list):
|
||||
tools = None
|
||||
return {
|
||||
"name": entry.get("name"),
|
||||
"source": entry.get("source"),
|
||||
"transport": entry.get("transport"),
|
||||
"description": (manifest.get("description") if isinstance(manifest, dict) else None) or "",
|
||||
"enabled": entry.get("enabled", True),
|
||||
"last_health_status": entry.get("last_health_status"),
|
||||
"last_error": entry.get("last_error"),
|
||||
"last_health_check_at": entry.get("last_health_check_at"),
|
||||
"tool_count": (len(tools) if tools is not None else None),
|
||||
}
|
||||
|
||||
|
||||
def _mcp_error_response(exc: MCPError, *, default_status: int = 400) -> web.Response:
    """Render a structured MCPError as a JSON error response.

    ``error`` mirrors ``what`` so generic frontend error handling (which
    only reads ``error``) still works alongside the richer fields.
    """
    payload = {
        "error": exc.what,
        "code": exc.code.value,
        "what": exc.what,
        "why": exc.why,
        "fix": exc.fix,
    }
    return web.json_response(payload, status=default_status)
|
||||
|
||||
|
||||
async def handle_list_servers(request: web.Request) -> web.Response:
    """GET /api/mcp/servers — list every server the queen actually uses.

    Merges two sources:

    - ``MCPRegistry.list_installed()`` — servers registered via
      ``hive mcp add`` / the ``/api/mcp/servers`` POST route, stored in
      ``~/.hive/mcp_registry/installed.json``. These carry
      ``source: "local"`` (user-added) or ``source: "registry"``
      (installed from the remote registry).
    - Repo-baked queen servers from
      ``core/framework/agents/queen/mcp_servers.json``. These are loaded
      directly by the queen's ``ToolRegistry`` at boot and never touch
      ``MCPRegistry``; we surface them here so the UI reflects what the
      queen really talks to. They are not removable from the UI because
      editing them requires changing the repo.

    If a name collides between the two sources, the registry entry wins
    because that's the one the user has customized.
    """
    registry = _registry()
    installed = [_server_to_summary(item) for item in registry.list_installed()]
    taken_names = {summary.get("name") for summary in installed}

    builtins = [
        entry for entry in _package_builtin_servers() if entry.get("name") not in taken_names
    ]

    return web.json_response({"servers": [*builtins, *installed]})
|
||||
|
||||
|
||||
async def handle_add_server(request: web.Request) -> web.Response:
    """POST /api/mcp/servers — register a local MCP server.

    Body mirrors ``MCPRegistry.add_local`` args:

    ::

        {
            "name": "my-tool",
            "transport": "stdio" | "http" | "sse" | "unix",
            "command": "...", "args": [...], "env": {...}, "cwd": "...",
            "url": "...", "headers": {...},
            "socket_path": "...",
            "description": "..."
        }

    Returns 201 with a server summary on success, 409 when the name is
    already taken, 400 on validation failure, 500 on unexpected errors.
    """
    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    if not isinstance(body, dict):
        return web.json_response({"error": "Body must be a JSON object"}, status=400)

    name = body.get("name")
    transport = body.get("transport")
    if not isinstance(name, str) or not name.strip():
        return web.json_response({"error": "'name' is required"}, status=400)
    if transport not in _VALID_TRANSPORTS:
        return web.json_response(
            {"error": f"'transport' must be one of {sorted(_VALID_TRANSPORTS)}"},
            status=400,
        )

    # Normalize once so registration, logging, and the echoed summary all
    # agree on the canonical server name. (Previously only add_local saw
    # the stripped name, so the response could echo stray whitespace.)
    name = name.strip()

    reg = _registry()
    try:
        entry = reg.add_local(
            name=name,
            transport=transport,
            command=body.get("command"),
            args=body.get("args"),
            env=body.get("env"),
            cwd=body.get("cwd"),
            url=body.get("url"),
            headers=body.get("headers"),
            socket_path=body.get("socket_path"),
            description=body.get("description", ""),
        )
    except MCPError as exc:
        # Duplicate-name errors map to 409 Conflict; everything else is a
        # plain validation failure.
        status = 409 if "already exists" in exc.what else 400
        return _mcp_error_response(exc, default_status=status)
    except Exception as exc:
        logger.exception("MCP add_local failed for %r", name)
        return web.json_response({"error": str(exc)}, status=500)

    summary = _server_to_summary({"name": name, **entry})
    return web.json_response(
        {
            "server": summary,
            "hint": "Start a new queen session to use this server's tools.",
        },
        status=201,
    )
|
||||
|
||||
|
||||
async def handle_remove_server(request: web.Request) -> web.Response:
    """DELETE /api/mcp/servers/{name} — remove a locally-registered server.

    Only ``source: "local"`` entries may be removed from the UI;
    registry-installed and built-in servers are managed elsewhere.
    Returns 404 when the server is unknown, 400 when it is not removable.
    """
    name = request.match_info["name"]
    reg = _registry()

    existing = reg.get_server(name)
    if existing is None:
        return web.json_response({"error": f"Server '{name}' not installed"}, status=404)
    if existing.get("source") != "local":
        # Previously this claimed "is a built-in" for ANY non-local source,
        # which mislabels registry-installed servers. Report the actual
        # source instead.
        source = existing.get("source")
        return web.json_response(
            {
                "error": (
                    f"Server '{name}' has source '{source}' and cannot be removed "
                    "from the UI; only locally-added servers are removable."
                ),
            },
            status=400,
        )

    try:
        reg.remove(name)
    except MCPError as exc:
        return _mcp_error_response(exc, default_status=404)
    return web.json_response({"removed": name})
|
||||
|
||||
|
||||
async def handle_set_enabled(request: web.Request, *, enabled: bool) -> web.Response:
    """Flip one server's enabled flag; shared by the enable/disable routes."""
    name = request.match_info["name"]
    reg = _registry()
    toggle = reg.enable if enabled else reg.disable
    try:
        toggle(name)
    except MCPError as exc:
        return _mcp_error_response(exc, default_status=404)
    return web.json_response({"name": name, "enabled": enabled})
|
||||
|
||||
|
||||
async def handle_enable(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/enable."""
    response = await handle_set_enabled(request, enabled=True)
    return response
|
||||
|
||||
|
||||
async def handle_disable(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/disable."""
    response = await handle_set_enabled(request, enabled=False)
    return response
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
    """POST /api/mcp/servers/{name}/health — probe one server."""
    name = request.match_info["name"]
    registry = _registry()
    try:
        # MCPRegistry.health_check blocks on subprocess IO — run it off
        # the event loop so the HTTP worker stays responsive.
        import asyncio

        probe = await asyncio.to_thread(registry.health_check, name)
    except MCPError as exc:
        return _mcp_error_response(exc, default_status=404)
    except Exception as exc:
        logger.exception("MCP health_check failed for %r", name)
        return web.json_response({"error": str(exc)}, status=500)
    return web.json_response(probe)
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
    """Register MCP server CRUD routes."""
    router = app.router
    router.add_get("/api/mcp/servers", handle_list_servers)
    router.add_post("/api/mcp/servers", handle_add_server)
    router.add_delete("/api/mcp/servers/{name}", handle_remove_server)
    router.add_post("/api/mcp/servers/{name}/enable", handle_enable)
    router.add_post("/api/mcp/servers/{name}/disable", handle_disable)
    router.add_post("/api/mcp/servers/{name}/health", handle_health)
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Per-queen MCP tool allowlist routes.
|
||||
|
||||
- GET /api/queen/{queen_id}/tools -- enumerate the queen's tool surface
|
||||
- PATCH /api/queen/{queen_id}/tools -- set or clear the MCP tool allowlist
|
||||
|
||||
Lifecycle and synthetic tools (``ask_user``) are always part of the queen's
|
||||
surface in INDEPENDENT mode and are returned with ``editable: false``. MCP
|
||||
tools are grouped by origin server and carry per-tool ``enabled`` flags.
|
||||
|
||||
The allowlist is persisted in a dedicated ``tools.json`` sidecar at
|
||||
``~/.hive/agents/queens/{queen_id}/tools.json``:
|
||||
|
||||
- ``null`` / missing file -> "allow every MCP tool" (default)
|
||||
- ``[]`` -> explicitly disable every MCP tool
|
||||
- ``["foo", "bar"]`` -> only these MCP tools pass through to the LLM
|
||||
|
||||
Filtering happens in ``QueenPhaseState.rebuild_independent_filter`` so the
|
||||
LLM prompt cache stays warm between saves.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.agents.queen.queen_profiles import (
|
||||
ensure_default_queens,
|
||||
load_queen_profile,
|
||||
)
|
||||
from framework.agents.queen.queen_tools_config import (
|
||||
delete_queen_tools_config,
|
||||
load_queen_tools_config,
|
||||
tools_config_exists,
|
||||
update_queen_tools_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SYNTHETIC_NAMES = {"ask_user"}
|
||||
|
||||
|
||||
async def _ensure_manager_catalog(manager: Any) -> dict[str, list[dict[str, Any]]]:
    """Return the cached MCP tool catalog, building it on first call.

    ``queen_orchestrator.create_queen`` populates ``_mcp_tool_catalog`` on
    every queen boot. On a fresh backend process the user may open the
    Tool Library before any queen session has started, so the catalog is
    empty. In that case we build one from the shared MCP config; the
    first call pays an MCP-subprocess-spawn cost, subsequent calls are
    cache hits. The build runs off the event loop via asyncio.to_thread
    so the HTTP worker stays responsive while MCP servers initialize.

    Returns a ``{server_name: [tool entry dicts]}`` mapping, or ``{}`` when
    there is no manager or the bootstrap build fails (failure is logged,
    not raised — callers treat an empty catalog as "stale").
    """
    # No manager (e.g. a bare test app) means no place to cache — bail out.
    if manager is None:
        return {}
    catalog = getattr(manager, "_mcp_tool_catalog", None)
    # Only a *non-empty* dict counts as a cache hit; an empty one means no
    # queen has booted yet, so fall through and build.
    if isinstance(catalog, dict) and catalog:
        return catalog
    try:
        import asyncio

        from framework.server.queen_orchestrator import build_queen_tool_registry_bare

        registry, built = await asyncio.to_thread(build_queen_tool_registry_bare)
        # Stash both the catalog and the registry on the manager so later
        # calls (and _lifecycle_entries_without_session) reuse them.
        manager._mcp_tool_catalog = built  # type: ignore[attr-defined]
        manager._bootstrap_tool_registry = registry  # type: ignore[attr-defined]
        return built
    except Exception:
        # Best-effort: the route still works with an empty (stale) catalog.
        logger.warning("Tool catalog bootstrap failed", exc_info=True)
        return {}
|
||||
|
||||
|
||||
def _lifecycle_entries_without_session(
    manager: Any,
    mcp_names: set[str],
) -> list[dict[str, Any]]:
    """Derive lifecycle tool names from the registry even without a session.

    We register queen lifecycle tools against a temporary registry using a
    minimal stub, then subtract the MCP-origin set and the synthetic set.
    The result matches what the queen sees at runtime (minus context-
    specific variants).

    Args:
        manager: the session manager (may be ``None``); only its
            ``_bootstrap_tool_registry`` stash is read.
        mcp_names: every known MCP-origin tool name; these are excluded so
            only lifecycle tools remain.

    Returns entries of shape ``{"name", "description", "editable": False}``,
    or ``[]`` when no bootstrap registry is available.
    """
    registry = getattr(manager, "_bootstrap_tool_registry", None)
    # If the bootstrap registry exists but doesn't carry lifecycle tools
    # yet, register them now.
    if registry is not None and not getattr(registry, "_lifecycle_bootstrap_done", False):
        try:
            from types import SimpleNamespace

            from framework.tools.queen_lifecycle_tools import register_queen_lifecycle_tools

            # Minimal stand-in for a real session object: the lifecycle
            # registrar only needs these attributes to exist.
            stub_session = SimpleNamespace(
                id="tool-library-bootstrap",
                colony_runtime=None,
                event_bus=None,
                worker_path=None,
                phase_state=None,
                llm=None,
            )
            register_queen_lifecycle_tools(
                registry,
                session=stub_session,
                session_id=stub_session.id,
                session_manager=None,
                manager_session_id=stub_session.id,
                phase_state=None,
            )
            # Flag the registry so we never double-register lifecycle tools.
            registry._lifecycle_bootstrap_done = True  # type: ignore[attr-defined]
        except Exception:
            # Best-effort: the listing below still works, just without
            # lifecycle entries.
            logger.debug("lifecycle bootstrap failed", exc_info=True)

    if registry is None:
        return []

    out: list[dict[str, Any]] = []
    # sorted() over the name->tool mapping gives a stable, name-ordered list.
    for name, tool in sorted(registry.get_tools().items()):
        if name in mcp_names or name in _SYNTHETIC_NAMES:
            continue
        out.append(
            {
                "name": tool.name,
                "description": tool.description,
                "editable": False,
            }
        )
    return out
|
||||
|
||||
|
||||
def _synthetic_entries() -> list[dict[str, Any]]:
|
||||
"""Return display metadata for synthetic tools injected by the agent loop.
|
||||
|
||||
Kept behind a lazy import so test harnesses that don't wire the agent
|
||||
loop can still hit this route without blowing up.
|
||||
"""
|
||||
try:
|
||||
from framework.agent_loop.internals.synthetic_tools import build_ask_user_tool
|
||||
|
||||
tool = build_ask_user_tool()
|
||||
return [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
except Exception:
|
||||
return [
|
||||
{
|
||||
"name": "ask_user",
|
||||
"description": "Pause and ask the user a structured question.",
|
||||
"editable": False,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _live_queen_session(manager: Any, queen_id: str) -> Any:
|
||||
"""Return any live DM session owned by this queen, or ``None``."""
|
||||
sessions = getattr(manager, "_sessions", None) or {}
|
||||
for session in sessions.values():
|
||||
if getattr(session, "queen_name", None) != queen_id:
|
||||
continue
|
||||
# Prefer DM (non-colony) sessions
|
||||
if getattr(session, "colony_runtime", None) is None:
|
||||
return session
|
||||
return None
|
||||
|
||||
|
||||
def _render_mcp_servers(
|
||||
*,
|
||||
mcp_tool_names_by_server: dict[str, list[dict[str, Any]]],
|
||||
enabled_mcp_tools: list[str] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Shape the mcp_tool_catalog entries for the API response."""
|
||||
allowed: set[str] | None = None if enabled_mcp_tools is None else set(enabled_mcp_tools)
|
||||
servers: list[dict[str, Any]] = []
|
||||
for server_name in sorted(mcp_tool_names_by_server):
|
||||
entries = mcp_tool_names_by_server[server_name]
|
||||
tools = []
|
||||
for entry in entries:
|
||||
name = entry.get("name")
|
||||
enabled = True if allowed is None else name in allowed
|
||||
tools.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": entry.get("description", ""),
|
||||
"input_schema": entry.get("input_schema", {}),
|
||||
"enabled": enabled,
|
||||
}
|
||||
)
|
||||
servers.append({"name": server_name, "tools": tools})
|
||||
return servers
|
||||
|
||||
|
||||
def _catalog_from_live_session(session: Any) -> dict[str, list[dict[str, Any]]]:
    """Rebuild a per-server tool catalog from a live queen session.

    The session's registry is authoritative — this reflects any hot-added
    MCP servers since the manager-level snapshot was cached.

    Two paths:

    1. No ``_queen_tool_registry`` on the session: reconstruct from the
       phase state's tool lists. Server attribution is lost there, so all
       MCP tools land under a single "MCP Tools" pseudo-server.
    2. Registry present: walk its server->tool-names map and resolve each
       name against the registry's tool objects.

    Returns ``{}`` when nothing can be reconstructed (caller treats that
    as a stale catalog).
    """
    registry = getattr(session, "_queen_tool_registry", None)
    if registry is None:
        # session._queen_tools_by_name is a stash from create_queen; we
        # only have registry via the tools list, so reconstruct from the
        # phase state instead.
        phase_state = getattr(session, "phase_state", None)
        if phase_state is None:
            return {}
        # `or set()` / `or []` guard against attributes set to None.
        mcp_names = getattr(phase_state, "mcp_tool_names_all", set()) or set()
        independent_tools = getattr(phase_state, "independent_tools", []) or []
        result: dict[str, list[dict[str, Any]]] = {"MCP Tools": []}
        for tool in independent_tools:
            # Keep only MCP-origin tools; lifecycle/synthetic are excluded.
            if tool.name not in mcp_names:
                continue
            result["MCP Tools"].append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.parameters,
                }
            )
        # An empty pseudo-server would masquerade as a valid catalog —
        # return {} instead so callers see it as stale.
        return result if result["MCP Tools"] else {}

    server_map = getattr(registry, "_mcp_server_tools", {}) or {}
    tools_by_name = {t.name: t for t in registry.get_tools().values()}
    catalog: dict[str, list[dict[str, Any]]] = {}
    for server_name, tool_names in server_map.items():
        entries: list[dict[str, Any]] = []
        for name in sorted(tool_names):
            tool = tools_by_name.get(name)
            # Names in the server map may lag the registry (e.g. a tool
            # was unregistered) — skip silently.
            if tool is None:
                continue
            entries.append(
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.parameters,
                }
            )
        catalog[server_name] = entries
    return catalog
|
||||
|
||||
|
||||
def _lifecycle_entries(
    *,
    session: Any,
    mcp_tool_names_all: set[str],
) -> list[dict[str, Any]]:
    """Lifecycle tools = independent_tools minus MCP-origin minus synthetic.

    We compute this from a live session when available so the list exactly
    matches what the queen actually sees on her next turn.
    """
    if session is None:
        return []
    phase_state = getattr(session, "phase_state", None)
    if phase_state is None:
        return []
    excluded = mcp_tool_names_all | _SYNTHETIC_NAMES
    entries = [
        {
            "name": tool.name,
            "description": tool.description,
            "editable": False,
        }
        for tool in (getattr(phase_state, "independent_tools", []) or [])
        if tool.name not in excluded
    ]
    entries.sort(key=lambda item: item["name"])
    return entries
|
||||
|
||||
|
||||
async def handle_get_tools(request: web.Request) -> web.Response:
    """GET /api/queen/{queen_id}/tools — enumerate tool surface for the UI.

    Response fields: ``enabled_mcp_tools`` (the persisted allowlist or
    ``None`` for allow-all), ``is_role_default`` (no tools.json sidecar
    exists yet), ``stale`` (no catalog could be built — tool lists may be
    incomplete), plus ``lifecycle``/``synthetic``/``mcp_servers`` sections.

    Returns 404 when the queen profile does not exist.
    """
    queen_id = request.match_info["queen_id"]
    ensure_default_queens()
    try:
        load_queen_profile(queen_id)
    except FileNotFoundError:
        return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)

    manager = request.app.get("manager")
    session = _live_queen_session(manager, queen_id) if manager is not None else None

    # Prefer a live session's registry for freshness. Otherwise use (or
    # build on demand) the manager-level catalog so the Tool Library works
    # even before any queen has been started in this process.
    if session is not None:
        catalog = _catalog_from_live_session(session)
    else:
        catalog = await _ensure_manager_catalog(manager)
    # Empty catalog -> we could not enumerate MCP tools at all; flag it so
    # the frontend can warn rather than show a misleading empty list.
    stale = not catalog

    # Flatten the per-server catalog into the full MCP-origin name set,
    # used below to separate lifecycle tools from MCP tools.
    mcp_tool_names_all: set[str] = set()
    for entries in catalog.values():
        for entry in entries:
            if entry.get("name"):
                mcp_tool_names_all.add(entry["name"])

    if session is not None:
        lifecycle = _lifecycle_entries(
            session=session,
            mcp_tool_names_all=mcp_tool_names_all,
        )
    else:
        lifecycle = _lifecycle_entries_without_session(manager, mcp_tool_names_all)

    # Allowlist lives in the dedicated tools.json sidecar; helper
    # migrates legacy profile.yaml field on first read, and falls back
    # to the role-based default when no sidecar exists.
    enabled_mcp_tools = load_queen_tools_config(queen_id, mcp_catalog=catalog)
    is_role_default = not tools_config_exists(queen_id)

    response = {
        "queen_id": queen_id,
        "enabled_mcp_tools": enabled_mcp_tools,
        "is_role_default": is_role_default,
        "stale": stale,
        "lifecycle": lifecycle,
        "synthetic": _synthetic_entries(),
        "mcp_servers": _render_mcp_servers(
            mcp_tool_names_by_server=catalog,
            enabled_mcp_tools=enabled_mcp_tools,
        ),
    }
    return web.json_response(response)
|
||||
|
||||
|
||||
async def handle_patch_tools(request: web.Request) -> web.Response:
    """PATCH /api/queen/{queen_id}/tools — persist the MCP tool allowlist.

    Body: ``{"enabled_mcp_tools": null | string[]}``.

    - ``null`` resets to "allow every MCP tool" (default).
    - A list is validated against the known MCP catalog; unknown names
      are rejected with 400 so the frontend catches typos.

    Flow: validate body -> validate queen -> validate names against the
    catalog -> persist to tools.json -> hot-reload live sessions. Order
    matters: nothing is written until all validation has passed.
    """
    queen_id = request.match_info["queen_id"]
    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    if not isinstance(body, dict) or "enabled_mcp_tools" not in body:
        return web.json_response(
            {"error": "Body must be an object with an 'enabled_mcp_tools' field"},
            status=400,
        )

    enabled = body["enabled_mcp_tools"]
    if enabled is not None:
        if not isinstance(enabled, list) or not all(isinstance(x, str) for x in enabled):
            return web.json_response(
                {"error": "'enabled_mcp_tools' must be null or a list of strings"},
                status=400,
            )

    ensure_default_queens()
    try:
        load_queen_profile(queen_id)
    except FileNotFoundError:
        return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)

    # Validate names against the known MCP tool catalog. We prefer a live
    # session's registry for the most up-to-date set, then fall back to
    # the manager-level snapshot (building it on demand if absent).
    manager = request.app.get("manager")
    session = _live_queen_session(manager, queen_id) if manager is not None else None
    if session is not None:
        catalog = _catalog_from_live_session(session)
    else:
        catalog = await _ensure_manager_catalog(manager)
    known_names: set[str] = set()
    for entries in catalog.values():
        for entry in entries:
            if entry.get("name"):
                known_names.add(entry["name"])

    # Only reject unknown names when we actually have a catalog to check
    # against; an empty (stale) catalog must not block saves.
    if enabled is not None and known_names:
        unknown = sorted(set(enabled) - known_names)
        if unknown:
            return web.json_response(
                {"error": "Unknown MCP tool name(s)", "unknown": unknown},
                status=400,
            )

    # Persist — tools.json sidecar, not profile.yaml.
    try:
        update_queen_tools_config(queen_id, enabled)
    except FileNotFoundError:
        return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)

    # Hot-reload every live DM session for this queen. The filter memo is
    # rebuilt so the very next turn sees the new allowlist without a
    # session restart, and the prompt cache is invalidated exactly once.
    refreshed = 0
    sessions = getattr(manager, "_sessions", None) or {}
    for sess in sessions.values():
        if getattr(sess, "queen_name", None) != queen_id:
            continue
        phase_state = getattr(sess, "phase_state", None)
        if phase_state is None:
            continue
        phase_state.enabled_mcp_tools = enabled
        # Duck-typed rebuild hook — older phase-state objects may lack it.
        rebuild = getattr(phase_state, "rebuild_independent_filter", None)
        if callable(rebuild):
            try:
                rebuild()
                refreshed += 1
            except Exception:
                # A single session failing to refresh must not fail the
                # save; it will pick up the allowlist on restart.
                logger.debug(
                    "Queen tools: rebuild_independent_filter failed for session %s",
                    getattr(sess, "id", "?"),
                    exc_info=True,
                )

    logger.info(
        "Queen tools: queen_id=%s allowlist=%s refreshed_sessions=%d",
        queen_id,
        "null" if enabled is None else f"{len(enabled)} tool(s)",
        refreshed,
    )
    return web.json_response(
        {
            "queen_id": queen_id,
            "enabled_mcp_tools": enabled,
            "refreshed_sessions": refreshed,
        }
    )
|
||||
|
||||
|
||||
async def handle_delete_tools(request: web.Request) -> web.Response:
    """DELETE /api/queen/{queen_id}/tools — drop the sidecar, fall back to role defaults.

    Users click "Reset to role default" in the Tool Library. That
    removes ``tools.json`` so the queen's effective allowlist becomes
    the role-based default (or allow-all if the queen has no role
    entry). Live sessions are refreshed so the next turn reflects the
    change without a restart.

    Returns 404 when the queen profile does not exist; otherwise 200 with
    the recomputed effective allowlist.
    """
    queen_id = request.match_info["queen_id"]
    ensure_default_queens()
    try:
        load_queen_profile(queen_id)
    except FileNotFoundError:
        return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)

    # `removed` is False when there was no sidecar to begin with (reset is
    # then a no-op, but we still refresh sessions for consistency).
    removed = delete_queen_tools_config(queen_id)

    # Recompute the queen's effective allowlist from the role defaults
    # so we can hot-reload live sessions in one pass (same shape as
    # PATCH).
    manager = request.app.get("manager")
    session = _live_queen_session(manager, queen_id) if manager is not None else None
    if session is not None:
        catalog = _catalog_from_live_session(session)
    else:
        catalog = await _ensure_manager_catalog(manager)
    new_enabled = load_queen_tools_config(queen_id, mcp_catalog=catalog)

    refreshed = 0
    sessions = getattr(manager, "_sessions", None) or {}
    for sess in sessions.values():
        if getattr(sess, "queen_name", None) != queen_id:
            continue
        phase_state = getattr(sess, "phase_state", None)
        if phase_state is None:
            continue
        phase_state.enabled_mcp_tools = new_enabled
        # Duck-typed rebuild hook — older phase-state objects may lack it.
        rebuild = getattr(phase_state, "rebuild_independent_filter", None)
        if callable(rebuild):
            try:
                rebuild()
                refreshed += 1
            except Exception:
                # Best-effort: a failed refresh only delays the new
                # allowlist until that session restarts.
                logger.debug(
                    "Queen tools: rebuild_independent_filter failed for session %s",
                    getattr(sess, "id", "?"),
                    exc_info=True,
                )

    logger.info(
        "Queen tools: queen_id=%s reset-to-default removed=%s refreshed_sessions=%d",
        queen_id,
        removed,
        refreshed,
    )
    return web.json_response(
        {
            "queen_id": queen_id,
            "removed": removed,
            "enabled_mcp_tools": new_enabled,
            "is_role_default": True,
            "refreshed_sessions": refreshed,
        }
    )
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
    """Register queen-tools routes."""
    router = app.router
    router.add_get("/api/queen/{queen_id}/tools", handle_get_tools)
    router.add_patch("/api/queen/{queen_id}/tools", handle_patch_tools)
    router.add_delete("/api/queen/{queen_id}/tools", handle_delete_tools)
|
||||
@@ -248,15 +248,22 @@ async def handle_queen_session(request: web.Request) -> web.Response:
|
||||
# Skip colony sessions: a colony forked from this queen also carries
|
||||
# queen_name == queen_id, but it has a worker loaded (colony_id /
|
||||
# worker_path set) and is the colony's chat, not the queen's DM.
|
||||
for session in manager.list_sessions():
|
||||
if session.queen_name == queen_id and session.colony_id is None and session.worker_path is None:
|
||||
return web.json_response(
|
||||
{
|
||||
"session_id": session.id,
|
||||
"queen_id": queen_id,
|
||||
"status": "live",
|
||||
}
|
||||
)
|
||||
# When multiple DM sessions for this queen are live at once (e.g. the
|
||||
# user created a new session, then navigated away and back), return
|
||||
# the most recently loaded one so we don't resurrect a stale older
|
||||
# session ahead of a freshly created one.
|
||||
live_matches = [
|
||||
s for s in manager.list_sessions() if s.queen_name == queen_id and s.colony_id is None and s.worker_path is None
|
||||
]
|
||||
if live_matches:
|
||||
latest = max(live_matches, key=lambda s: s.loaded_at)
|
||||
return web.json_response(
|
||||
{
|
||||
"session_id": latest.id,
|
||||
"queen_id": queen_id,
|
||||
"status": "live",
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Find the most recent cold session for this queen and resume it.
|
||||
# IMPORTANT: skip sessions that don't belong in the queen DM:
|
||||
@@ -378,6 +385,8 @@ async def handle_select_queen_session(request: web.Request) -> web.Response:
|
||||
|
||||
async def handle_new_queen_session(request: web.Request) -> web.Response:
|
||||
"""POST /api/queen/{queen_id}/session/new -- create a fresh queen session."""
|
||||
from framework.tools.queen_lifecycle_tools import QUEEN_PHASES
|
||||
|
||||
queen_id = request.match_info["queen_id"]
|
||||
manager = request.app["manager"]
|
||||
|
||||
@@ -387,9 +396,25 @@ async def handle_new_queen_session(request: web.Request) -> web.Response:
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_id}' not found"}, status=404)
|
||||
|
||||
body = await request.json() if request.can_read_body else {}
|
||||
if request.can_read_body:
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict):
|
||||
return web.json_response({"error": "Request body must be a JSON object"}, status=400)
|
||||
else:
|
||||
body = {}
|
||||
initial_prompt = body.get("initial_prompt")
|
||||
initial_phase = body.get("initial_phase") or "independent"
|
||||
if initial_phase not in QUEEN_PHASES:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"Invalid initial_phase '{initial_phase}'",
|
||||
"valid": sorted(QUEEN_PHASES),
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
session = await manager.create_session(
|
||||
initial_prompt=initial_prompt,
|
||||
|
||||
@@ -122,8 +122,19 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
(equivalent to the old POST /api/agents). Otherwise creates a queen-only
|
||||
session that can later have a colony loaded via POST /sessions/{id}/colony.
|
||||
"""
|
||||
from framework.agents.queen.queen_profiles import ensure_default_queens, load_queen_profile
|
||||
from framework.tools.queen_lifecycle_tools import QUEEN_PHASES
|
||||
|
||||
manager = _get_manager(request)
|
||||
body = await request.json() if request.can_read_body else {}
|
||||
if request.can_read_body:
|
||||
try:
|
||||
body = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
||||
if not isinstance(body, dict):
|
||||
return web.json_response({"error": "Request body must be a JSON object"}, status=400)
|
||||
else:
|
||||
body = {}
|
||||
agent_path = body.get("agent_path")
|
||||
agent_id = body.get("agent_id")
|
||||
session_id = body.get("session_id")
|
||||
@@ -134,6 +145,21 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
initial_phase = body.get("initial_phase")
|
||||
worker_name = body.get("worker_name")
|
||||
|
||||
if initial_phase is not None and initial_phase not in QUEEN_PHASES:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"Invalid initial_phase '{initial_phase}'",
|
||||
"valid": sorted(QUEEN_PHASES),
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
if queen_name:
|
||||
ensure_default_queens()
|
||||
try:
|
||||
load_queen_profile(queen_name)
|
||||
except FileNotFoundError:
|
||||
return web.json_response({"error": f"Queen '{queen_name}' not found"}, status=404)
|
||||
|
||||
if agent_path:
|
||||
try:
|
||||
agent_path = str(validate_agent_path(agent_path))
|
||||
@@ -160,6 +186,7 @@ async def handle_create_session(request: web.Request) -> web.Response:
|
||||
model=model,
|
||||
initial_prompt=initial_prompt,
|
||||
queen_resume_from=queen_resume_from,
|
||||
queen_name=queen_name,
|
||||
initial_phase=initial_phase,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,112 @@
|
||||
"""REST routes for task lists.
|
||||
|
||||
GET /api/tasks/{task_list_id} -- snapshot of one list
|
||||
GET /api/colonies/{colony_id}/task_lists -- helper for colony view
|
||||
GET /api/sessions/{session_id}/task_list_id -- helper for session view
|
||||
|
||||
The task_list_id segment uses URL-encoded colons (``colony%3Aabc`` /
|
||||
``session%3Aagent%3Asess``); aiohttp decodes them automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from framework.tasks import get_task_store
|
||||
from framework.tasks.scoping import (
|
||||
colony_task_list_id,
|
||||
session_task_list_id,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_get_task_list(request: web.Request) -> web.Response:
|
||||
raw = request.match_info.get("task_list_id", "")
|
||||
if not raw:
|
||||
return web.json_response({"error": "task_list_id required"}, status=400)
|
||||
|
||||
store = get_task_store()
|
||||
if not await store.list_exists(raw):
|
||||
return web.json_response(
|
||||
{"error": f"Task list {raw!r} not found", "task_list_id": raw, "tasks": []},
|
||||
status=404,
|
||||
)
|
||||
|
||||
meta = await store.get_meta(raw)
|
||||
records = await store.list_tasks(raw)
|
||||
return web.json_response(
|
||||
{
|
||||
"task_list_id": raw,
|
||||
"role": meta.role.value if meta else "session",
|
||||
"meta": meta.model_dump(mode="json") if meta else None,
|
||||
"tasks": [
|
||||
{
|
||||
"id": r.id,
|
||||
"subject": r.subject,
|
||||
"description": r.description,
|
||||
"active_form": r.active_form,
|
||||
"owner": r.owner,
|
||||
"status": r.status.value,
|
||||
"blocks": list(r.blocks),
|
||||
"blocked_by": list(r.blocked_by),
|
||||
"metadata": dict(r.metadata),
|
||||
"created_at": r.created_at,
|
||||
"updated_at": r.updated_at,
|
||||
}
|
||||
for r in records
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_get_colony_task_lists(request: web.Request) -> web.Response:
|
||||
"""Return template_task_list_id and queen_session_task_list_id for a colony."""
|
||||
colony_id = request.match_info.get("colony_id", "")
|
||||
if not colony_id:
|
||||
return web.json_response({"error": "colony_id required"}, status=400)
|
||||
|
||||
template_id = colony_task_list_id(colony_id)
|
||||
# Queen's session list — the queen-of-colony's session_id == the
|
||||
# browser-facing colony session id. The frontend already knows that
|
||||
# value; we surface what we have on disk for completeness.
|
||||
queen_session_id = request.query.get("queen_session_id")
|
||||
queen_list_id = session_task_list_id("queen", queen_session_id) if queen_session_id else None
|
||||
return web.json_response(
|
||||
{
|
||||
"template_task_list_id": template_id,
|
||||
"queen_session_task_list_id": queen_list_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_get_session_task_list_id(request: web.Request) -> web.Response:
|
||||
"""Return task_list_id and picked_up_from for a session.
|
||||
|
||||
The session_id is the queen's session id or a worker's session id;
|
||||
both follow the same path. The agent_id is read from the request query
|
||||
(passed by the frontend, which already knows which agent the session
|
||||
belongs to).
|
||||
"""
|
||||
session_id = request.match_info.get("session_id", "")
|
||||
agent_id = request.query.get("agent_id", "queen")
|
||||
if not session_id:
|
||||
return web.json_response({"error": "session_id required"}, status=400)
|
||||
|
||||
task_list_id = session_task_list_id(agent_id, session_id)
|
||||
store = get_task_store()
|
||||
exists = await store.list_exists(task_list_id)
|
||||
return web.json_response(
|
||||
{
|
||||
"task_list_id": task_list_id if exists else None,
|
||||
"picked_up_from": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
app.router.add_get("/api/tasks/{task_list_id}", handle_get_task_list)
|
||||
app.router.add_get("/api/colonies/{colony_id}/task_lists", handle_get_colony_task_lists)
|
||||
app.router.add_get("/api/sessions/{session_id}/task_list_id", handle_get_session_task_list_id)
|
||||
@@ -1223,8 +1223,27 @@ class SessionManager:
|
||||
logger.info("Session '%s': shutdown reflection spawned", session_id)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
logger.warning("Session '%s': failed to spawn shutdown reflection", session_id, exc_info=True)
|
||||
except RuntimeError as exc:
|
||||
# Most common when a session is stopped after the event loop
|
||||
# has closed (e.g. during server shutdown or from an atexit
|
||||
# handler). The reflection would have had nothing to write
|
||||
# anyway — no new turns since the last periodic reflection.
|
||||
logger.warning(
|
||||
"Session '%s': shutdown reflection skipped — event loop unavailable (%s). "
|
||||
"Normal during server shutdown; anything worth persisting was saved by the "
|
||||
"periodic reflection after the last turn.",
|
||||
session_id,
|
||||
exc,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Session '%s': failed to spawn shutdown reflection: %s: %s. "
|
||||
"Check that queen_dir exists and session.llm is configured; full traceback follows.",
|
||||
session_id,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if session.queen_task is not None:
|
||||
session.queen_task.cancel()
|
||||
@@ -1516,8 +1535,46 @@ class SessionManager:
|
||||
tool_executor=queen_tool_executor,
|
||||
event_bus=session.event_bus,
|
||||
colony_id=session.id,
|
||||
# Wire the on-disk colony name and queen id so
|
||||
# ColonyRuntime auto-derives its override paths. DM sessions
|
||||
# have no colony_name (session.colony_name is None), which
|
||||
# keeps them out of the per-colony JSON store.
|
||||
colony_name=getattr(session, "colony_name", None),
|
||||
queen_id=getattr(session, "queen_name", None) or None,
|
||||
pipeline_stages=[], # queen pipeline runs in queen_orchestrator, not here
|
||||
)
|
||||
|
||||
# Per-colony tool allowlist, loaded from the colony's metadata.json
|
||||
# when this session is attached to a real forked colony. For pure
|
||||
# queen DM sessions (session.colony_name is None) we only capture
|
||||
# the MCP-origin set — the allowlist stays ``None`` so every MCP
|
||||
# tool passes through by default.
|
||||
try:
|
||||
mcp_tool_names_all: set[str] = set()
|
||||
mgr_catalog = getattr(self, "_mcp_tool_catalog", None)
|
||||
if isinstance(mgr_catalog, dict):
|
||||
for entries in mgr_catalog.values():
|
||||
for entry in entries:
|
||||
name = entry.get("name") if isinstance(entry, dict) else None
|
||||
if name:
|
||||
mcp_tool_names_all.add(name)
|
||||
enabled_mcp_tools: list[str] | None = None
|
||||
colony_name = getattr(session, "colony_name", None)
|
||||
if colony_name:
|
||||
# Colony tool allowlist lives in a dedicated tools.json
|
||||
# sidecar next to metadata.json. The helper migrates any
|
||||
# legacy field out of metadata.json on first read.
|
||||
from framework.host.colony_tools_config import load_colony_tools_config
|
||||
|
||||
enabled_mcp_tools = load_colony_tools_config(colony_name)
|
||||
colony.set_tool_allowlist(enabled_mcp_tools, mcp_tool_names_all)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Colony allowlist bootstrap failed for session %s",
|
||||
session.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await colony.start()
|
||||
session.colony = colony
|
||||
|
||||
@@ -1707,6 +1764,42 @@ class SessionManager:
|
||||
def list_sessions(self) -> list[Session]:
|
||||
return list(self._sessions.values())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Skill override helpers — used by routes_skills to find every live
|
||||
# SkillsManager affected by a queen- or colony-scope mutation so a
|
||||
# single HTTP call can reload them all.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def iter_queen_sessions(self, queen_id: str):
|
||||
"""Yield live sessions whose queen matches ``queen_id``."""
|
||||
for s in self._sessions.values():
|
||||
if getattr(s, "queen_name", None) == queen_id:
|
||||
yield s
|
||||
|
||||
def iter_colony_runtimes(
|
||||
self,
|
||||
*,
|
||||
queen_id: str | None = None,
|
||||
colony_name: str | None = None,
|
||||
):
|
||||
"""Yield live ``ColonyRuntime`` instances matching the filters.
|
||||
|
||||
``queen_id`` alone → every runtime whose ``queen_id`` matches
|
||||
(useful when the user toggles a queen-scope skill — all her
|
||||
colonies must reload). ``colony_name`` alone → the single
|
||||
runtime pinned to that colony. Both → intersection. No filters
|
||||
→ every live runtime (used by global ``/api/skills`` reload).
|
||||
"""
|
||||
for s in self._sessions.values():
|
||||
colony = getattr(s, "colony", None)
|
||||
if colony is None:
|
||||
continue
|
||||
if queen_id is not None and getattr(colony, "queen_id", None) != queen_id:
|
||||
continue
|
||||
if colony_name is not None and getattr(colony, "colony_name", None) != colony_name:
|
||||
continue
|
||||
yield colony
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cold session helpers (disk-only, no live runtime required)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Tests for the per-colony MCP tool allowlist filter + routes.
|
||||
|
||||
Covers:
|
||||
1. ``ColonyRuntime`` filter semantics (default-allow, allowlist, empty,
|
||||
lifecycle passes through).
|
||||
2. routes_colony_tools round trip (GET/PATCH, validation, 404).
|
||||
3. Colony index route for the Tool Library picker.
|
||||
|
||||
Routes never touch the real ``~/.hive/colonies`` tree — we redirect
|
||||
``COLONIES_DIR`` into ``tmp_path`` via monkeypatch.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.host.colony_runtime import ColonyRuntime
|
||||
from framework.llm.provider import Tool
|
||||
from framework.server import routes_colony_tools
|
||||
|
||||
|
||||
def _tool(name: str) -> Tool:
|
||||
return Tool(name=name, description=f"desc of {name}", parameters={"type": "object"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ColonyRuntime filter unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _bare_runtime() -> ColonyRuntime:
|
||||
rt = ColonyRuntime.__new__(ColonyRuntime)
|
||||
rt._enabled_mcp_tools = None
|
||||
rt._mcp_tool_names_all = set()
|
||||
return rt
|
||||
|
||||
|
||||
class TestColonyFilter:
|
||||
def test_default_is_noop(self):
|
||||
rt = _bare_runtime()
|
||||
tools = [_tool("mcp_a"), _tool("lc_b")]
|
||||
assert rt._apply_tool_allowlist(tools) == tools
|
||||
|
||||
def test_allowlist_gates_mcp_only(self):
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"mcp_a", "mcp_b"}
|
||||
rt._enabled_mcp_tools = ["mcp_a"]
|
||||
tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
|
||||
names = [t.name for t in rt._apply_tool_allowlist(tools)]
|
||||
assert names == ["mcp_a", "lc_c"]
|
||||
|
||||
def test_empty_allowlist_keeps_lifecycle(self):
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"mcp_a", "mcp_b"}
|
||||
rt._enabled_mcp_tools = []
|
||||
tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
|
||||
names = [t.name for t in rt._apply_tool_allowlist(tools)]
|
||||
assert names == ["lc_c"]
|
||||
|
||||
def test_setter_mutates_live_state(self):
|
||||
rt = _bare_runtime()
|
||||
rt.set_tool_allowlist(["x"], {"x", "y"})
|
||||
assert rt._enabled_mcp_tools == ["x"]
|
||||
assert rt._mcp_tool_names_all == {"x", "y"}
|
||||
|
||||
# Passing None on allowlist clears gating; mcp_tool_names_all
|
||||
# defaults to "keep current" so a subsequent caller doesn't need
|
||||
# to repeat the set.
|
||||
rt.set_tool_allowlist(None)
|
||||
assert rt._enabled_mcp_tools is None
|
||||
assert rt._mcp_tool_names_all == {"x", "y"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route round-trip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeSession:
|
||||
colony_name: str
|
||||
colony: Any = None
|
||||
colony_runtime: Any = None
|
||||
id: str = "sess-1"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeManager:
|
||||
_sessions: dict = field(default_factory=dict)
|
||||
_mcp_tool_catalog: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def colony_dir(tmp_path, monkeypatch):
|
||||
"""Point COLONIES_DIR into a tmp tree and seed a colony."""
|
||||
colonies = tmp_path / "colonies"
|
||||
colonies.mkdir()
|
||||
monkeypatch.setattr("framework.host.colony_metadata.COLONIES_DIR", colonies)
|
||||
monkeypatch.setattr("framework.host.colony_tools_config.COLONIES_DIR", colonies)
|
||||
|
||||
name = "my_colony"
|
||||
cdir = colonies / name
|
||||
cdir.mkdir()
|
||||
(cdir / "metadata.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"colony_name": name,
|
||||
"queen_name": "queen_technology",
|
||||
"created_at": "2026-04-20T00:00:00+00:00",
|
||||
}
|
||||
)
|
||||
)
|
||||
return colonies, name
|
||||
|
||||
|
||||
async def _app(manager: _FakeManager) -> web.Application:
|
||||
app = web.Application()
|
||||
app["manager"] = manager
|
||||
routes_colony_tools.register_routes(app)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_default_allow(colony_dir):
|
||||
_, name = colony_dir
|
||||
manager = _FakeManager(
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "read", "input_schema": {}},
|
||||
{"name": "write_file", "description": "write", "input_schema": {}},
|
||||
],
|
||||
}
|
||||
)
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get(f"/api/colony/{name}/tools")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["enabled_mcp_tools"] is None
|
||||
assert body["stale"] is False
|
||||
tools = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
|
||||
assert all(t["enabled"] for t in tools.values())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_persists_and_validates(colony_dir):
|
||||
colonies_dir, name = colony_dir
|
||||
manager = _FakeManager(
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "", "input_schema": {}},
|
||||
{"name": "write_file", "description": "", "input_schema": {}},
|
||||
]
|
||||
}
|
||||
)
|
||||
app = await _app(manager)
|
||||
tools_path = colonies_dir / name / "tools.json"
|
||||
metadata_path = colonies_dir / name / "metadata.json"
|
||||
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["read_file"]})
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["enabled_mcp_tools"] == ["read_file"]
|
||||
|
||||
# Persisted to tools.json; metadata.json does not carry the field.
|
||||
sidecar = json.loads(tools_path.read_text())
|
||||
assert sidecar["enabled_mcp_tools"] == ["read_file"]
|
||||
assert "updated_at" in sidecar
|
||||
meta = json.loads(metadata_path.read_text())
|
||||
assert "enabled_mcp_tools" not in meta
|
||||
|
||||
# GET reflects the allowlist
|
||||
resp = await client.get(f"/api/colony/{name}/tools")
|
||||
body = await resp.json()
|
||||
tools = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
|
||||
assert tools["read_file"]["enabled"] is True
|
||||
assert tools["write_file"]["enabled"] is False
|
||||
|
||||
# Unknown → 400
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["ghost"]})
|
||||
assert resp.status == 400
|
||||
assert "ghost" in (await resp.json()).get("unknown", [])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_refreshes_live_runtime(colony_dir):
|
||||
_, name = colony_dir
|
||||
|
||||
rt = _bare_runtime()
|
||||
rt._mcp_tool_names_all = {"read_file", "write_file"}
|
||||
rt.set_tool_allowlist(None)
|
||||
|
||||
session = _FakeSession(colony_name=name, colony=rt)
|
||||
manager = _FakeManager(
|
||||
_sessions={session.id: session},
|
||||
_mcp_tool_catalog={
|
||||
"coder-tools": [
|
||||
{"name": "read_file", "description": "", "input_schema": {}},
|
||||
{"name": "write_file", "description": "", "input_schema": {}},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.patch(f"/api/colony/{name}/tools", json={"enabled_mcp_tools": ["read_file"]})
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["refreshed_runtimes"] == 1
|
||||
assert rt._enabled_mcp_tools == ["read_file"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_for_unknown_colony(colony_dir):
|
||||
manager = _FakeManager()
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/colony/unknown/tools")
|
||||
assert resp.status == 404
|
||||
resp = await client.patch("/api/colony/unknown/tools", json={"enabled_mcp_tools": None})
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_index_lists_colonies(colony_dir):
|
||||
_, name = colony_dir
|
||||
manager = _FakeManager()
|
||||
app = await _app(manager)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/colonies/tools-index")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
entries = {c["name"]: c for c in body["colonies"]}
|
||||
assert name in entries
|
||||
assert entries[name]["queen_name"] == "queen_technology"
|
||||
assert entries[name]["has_allowlist"] is False
|
||||
|
||||
|
||||
def test_queen_allowlist_inherits_into_new_colony(tmp_path, monkeypatch):
|
||||
"""A colony forked with a curated queen inherits her allowlist.
|
||||
|
||||
Exercises the inheritance hook in
|
||||
``routes_execution.fork_session_into_colony`` without running the
|
||||
full fork machinery — we just call
|
||||
``update_colony_tools_config`` the same way the hook does and
|
||||
assert the colony's ``tools.json`` matches the queen's live list.
|
||||
"""
|
||||
colonies = tmp_path / "colonies"
|
||||
colonies.mkdir()
|
||||
monkeypatch.setattr("framework.host.colony_tools_config.COLONIES_DIR", colonies)
|
||||
|
||||
from framework.host.colony_tools_config import (
|
||||
load_colony_tools_config,
|
||||
update_colony_tools_config,
|
||||
)
|
||||
|
||||
colony_name = "forked_child"
|
||||
(colonies / colony_name).mkdir()
|
||||
|
||||
# Simulate: queen has a curated allowlist (e.g. role default resolved
|
||||
# to a concrete list). The inheritance hook copies it verbatim.
|
||||
queen_live_allowlist = ["read_file", "web_scrape", "csv_read"]
|
||||
update_colony_tools_config(colony_name, list(queen_live_allowlist))
|
||||
|
||||
assert load_colony_tools_config(colony_name) == queen_live_allowlist
|
||||
|
||||
|
||||
def test_legacy_metadata_field_migrates_to_sidecar(colony_dir):
|
||||
"""A legacy enabled_mcp_tools field in metadata.json is hoisted to tools.json."""
|
||||
colonies_dir, name = colony_dir
|
||||
meta_path = colonies_dir / name / "metadata.json"
|
||||
tools_path = colonies_dir / name / "tools.json"
|
||||
|
||||
# Seed legacy field in metadata.json.
|
||||
meta = json.loads(meta_path.read_text())
|
||||
meta["enabled_mcp_tools"] = ["read_file"]
|
||||
meta_path.write_text(json.dumps(meta))
|
||||
|
||||
from framework.host.colony_tools_config import load_colony_tools_config
|
||||
|
||||
# First load migrates.
|
||||
assert load_colony_tools_config(name) == ["read_file"]
|
||||
assert tools_path.exists()
|
||||
sidecar = json.loads(tools_path.read_text())
|
||||
assert sidecar["enabled_mcp_tools"] == ["read_file"]
|
||||
|
||||
# metadata.json no longer contains the field; provenance fields preserved.
|
||||
migrated = json.loads(meta_path.read_text())
|
||||
assert "enabled_mcp_tools" not in migrated
|
||||
assert migrated["queen_name"] == "queen_technology"
|
||||
|
||||
# Second load is a direct sidecar read.
|
||||
assert load_colony_tools_config(name) == ["read_file"]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Tests for the MCP server CRUD HTTP routes.
|
||||
|
||||
Monkey-patches ``MCPRegistry`` inside ``routes_mcp`` so the HTTP layer is
|
||||
exercised without reading or writing ``~/.hive/mcp_registry/installed.json``
|
||||
or spawning actual subprocesses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.loader.mcp_errors import MCPError, MCPErrorCode
|
||||
from framework.server import routes_mcp
|
||||
|
||||
|
||||
class _FakeRegistry:
|
||||
"""Stand-in for MCPRegistry — just enough surface for the routes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._servers: dict[str, dict[str, Any]] = {
|
||||
"built-in-seed": {
|
||||
"source": "registry",
|
||||
"transport": "stdio",
|
||||
"enabled": True,
|
||||
"manifest": {"description": "Factory-seeded server", "tools": []},
|
||||
"last_health_status": "healthy",
|
||||
"last_error": None,
|
||||
"last_health_check_at": None,
|
||||
}
|
||||
}
|
||||
|
||||
def initialize(self) -> None: # noqa: D401 — registry idempotent init
|
||||
return
|
||||
|
||||
def list_installed(self) -> list[dict[str, Any]]:
|
||||
return [{"name": name, **entry} for name, entry in self._servers.items()]
|
||||
|
||||
def get_server(self, name: str) -> dict | None:
|
||||
if name not in self._servers:
|
||||
return None
|
||||
return {"name": name, **self._servers[name]}
|
||||
|
||||
def add_local(self, *, name: str, transport: str, **kwargs: Any) -> dict:
|
||||
if name in self._servers:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Server '{name}' already exists",
|
||||
why="A server with this name is already registered locally.",
|
||||
fix=f"Run: hive mcp remove {name}",
|
||||
)
|
||||
entry = {
|
||||
"source": "local",
|
||||
"transport": transport,
|
||||
"enabled": True,
|
||||
"manifest": {"description": kwargs.get("description") or ""},
|
||||
"last_health_status": None,
|
||||
"last_error": None,
|
||||
"last_health_check_at": None,
|
||||
}
|
||||
self._servers[name] = entry
|
||||
return entry
|
||||
|
||||
def remove(self, name: str) -> None:
|
||||
if name not in self._servers:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what=f"Cannot remove server '{name}'",
|
||||
why="Server is not installed.",
|
||||
fix="Run: hive mcp list",
|
||||
)
|
||||
del self._servers[name]
|
||||
|
||||
def enable(self, name: str) -> None:
|
||||
if name not in self._servers:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what="not found",
|
||||
why="not found",
|
||||
fix="x",
|
||||
)
|
||||
self._servers[name]["enabled"] = True
|
||||
|
||||
def disable(self, name: str) -> None:
|
||||
if name not in self._servers:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_INSTALL_FAILED,
|
||||
what="not found",
|
||||
why="not found",
|
||||
fix="x",
|
||||
)
|
||||
self._servers[name]["enabled"] = False
|
||||
|
||||
def health_check(self, name: str) -> dict[str, Any]:
|
||||
if name not in self._servers:
|
||||
raise MCPError(
|
||||
code=MCPErrorCode.MCP_HEALTH_FAILED,
|
||||
what="not found",
|
||||
why="not found",
|
||||
fix="x",
|
||||
)
|
||||
return {"name": name, "status": "healthy", "tools": 3, "error": None}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry(monkeypatch):
|
||||
reg = _FakeRegistry()
|
||||
monkeypatch.setattr(routes_mcp, "_registry", lambda: reg)
|
||||
return reg
|
||||
|
||||
|
||||
async def _make_app() -> web.Application:
|
||||
app = web.Application()
|
||||
routes_mcp.register_routes(app)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_servers_returns_built_in(registry):
|
||||
app = await _make_app()
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/mcp/servers")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
names = {s["name"] for s in body["servers"]}
|
||||
# The registry fake carries one entry; the list also merges package-
|
||||
# baked entries from core/framework/agents/queen/mcp_servers.json so
|
||||
# the UI matches what the queen actually loads. Both should appear.
|
||||
assert "built-in-seed" in names
|
||||
sources = {s["name"]: s["source"] for s in body["servers"]}
|
||||
assert sources.get("built-in-seed") == "registry"
|
||||
# The package-baked servers (coder-tools/gcu-tools/hive_tools) carry
|
||||
# source=="built-in" and are non-removable.
|
||||
pkg_entries = [s for s in body["servers"] if s["source"] == "built-in"]
|
||||
assert pkg_entries, "expected at least one package-baked MCP server"
|
||||
assert all(s.get("removable") is False for s in pkg_entries)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_local_server(registry):
|
||||
app = await _make_app()
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.post(
|
||||
"/api/mcp/servers",
|
||||
json={
|
||||
"name": "my-tool",
|
||||
"transport": "stdio",
|
||||
"command": "echo",
|
||||
"args": ["hi"],
|
||||
"description": "says hi",
|
||||
},
|
||||
)
|
||||
assert resp.status == 201
|
||||
body = await resp.json()
|
||||
assert body["server"]["name"] == "my-tool"
|
||||
assert body["server"]["source"] == "local"
|
||||
|
||||
resp = await client.get("/api/mcp/servers")
|
||||
names = [s["name"] for s in (await resp.json())["servers"]]
|
||||
assert "my-tool" in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_rejects_duplicate(registry):
    """A second POST with the same name is refused with 409 plus a fix hint."""
    app = await _make_app()
    payload = {"name": "dup", "transport": "stdio", "command": "x"}
    async with TestClient(TestServer(app)) as client:
        await client.post("/api/mcp/servers", json=payload)
        second = await client.post("/api/mcp/servers", json=payload)
        assert second.status == 409
        detail = await second.json()
        assert "already exists" in detail["error"].lower()
        assert detail["fix"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_rejects_invalid_transport(registry):
    """An unknown transport value is rejected with 400."""
    app = await _make_app()
    async with TestClient(TestServer(app)) as client:
        response = await client.post(
            "/api/mcp/servers",
            json={"name": "x", "transport": "nope"},
        )
        assert response.status == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enable_disable_cycle(registry):
    """POST …/disable then …/enable flips the flag and persists it in the registry."""
    app = await _make_app()
    # Seed a local server
    registry.add_local(name="local-one", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        resp = await client.post("/api/mcp/servers/local-one/disable")
        assert resp.status == 200
        assert (await resp.json())["enabled"] is False
        # Reaches into the fake's private store to confirm the route
        # persisted the change rather than just echoing it back.
        assert registry._servers["local-one"]["enabled"] is False

        resp = await client.post("/api/mcp/servers/local-one/enable")
        assert resp.status == 200
        assert (await resp.json())["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_remove_local_only(registry):
    """DELETE removes local servers only: built-ins → 400, unknown → 404."""
    app = await _make_app()
    registry.add_local(name="local-two", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        # Built-ins are protected
        resp = await client.delete("/api/mcp/servers/built-in-seed")
        assert resp.status == 400

        # Missing
        resp = await client.delete("/api/mcp/servers/ghost")
        assert resp.status == 404

        # Happy path
        resp = await client.delete("/api/mcp/servers/local-two")
        assert resp.status == 200
        assert "local-two" not in registry._servers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check(registry):
    """POST …/health pings the server and reports status plus tool count.

    The ``monkeypatch`` fixture parameter was unused and has been dropped.
    """
    app = await _make_app()
    registry.add_local(name="pingable", transport="stdio", command="x")

    async with TestClient(TestServer(app)) as client:
        resp = await client.post("/api/mcp/servers/pingable/health")
        assert resp.status == 200
        body = await resp.json()
        assert body["status"] == "healthy"
        # NOTE(review): the tool count presumably comes from the registry
        # fake's fixed catalog — confirm against the fixture definition.
        assert body["tools"] == 3
|
||||
@@ -0,0 +1,443 @@
|
||||
"""Tests for the per-queen MCP tool allowlist filter + routes.
|
||||
|
||||
Covers:
|
||||
1. QueenPhaseState filter semantics (default-allow, allowlist, empty, phase-
|
||||
isolation, memo identity for LLM prompt-cache stability).
|
||||
2. routes_queen_tools round trip (GET, PATCH, validation, live-session
|
||||
hot-reload).
|
||||
|
||||
Route tests monkey-patch a tiny queen profile + manager catalog; they never
|
||||
spawn an MCP subprocess.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from framework.llm.provider import Tool
|
||||
from framework.server import routes_queen_tools
|
||||
from framework.tools.queen_lifecycle_tools import QueenPhaseState
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QueenPhaseState filter — pure unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _tool(name: str) -> Tool:
    """Build a minimal Tool stub whose description embeds *name*."""
    description = f"desc of {name}"
    return Tool(
        name=name,
        description=description,
        parameters={"type": "object"},
    )
|
||||
|
||||
|
||||
class TestPhaseStateFilter:
    """Pure unit tests for the MCP allowlist filter on QueenPhaseState.

    Conventions: ``enabled_mcp_tools is None`` means allow-all, a list
    means allowlist (empty list blocks every MCP tool); names in
    ``mcp_tool_names_all`` identify which tools are MCP (filterable) —
    everything else (lifecycle tools) always passes through.
    """

    def test_default_allow_returns_every_tool(self):
        # None allowlist → no filtering at all.
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = None
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["mcp_a", "mcp_b", "lc_c"]

    def test_allowlist_keeps_listed_mcp_plus_all_lifecycle(self):
        # Listed MCP tools survive; unlisted MCP tools are dropped;
        # lifecycle tools (not in mcp_tool_names_all) always survive.
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = ["mcp_a"]
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["mcp_a", "lc_c"]

    def test_empty_allowlist_keeps_only_lifecycle(self):
        # Empty list is NOT the same as None: it blocks every MCP tool.
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("mcp_b"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a", "mcp_b"}
        ps.enabled_mcp_tools = []
        ps.rebuild_independent_filter()

        names = [t.name for t in ps.get_current_tools()]
        assert names == ["lc_c"]

    def test_filter_isolated_to_independent_phase(self):
        # The allowlist only gates the "independent" phase tool list.
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.working_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a"}
        ps.enabled_mcp_tools = []
        ps.rebuild_independent_filter()

        # Independent → filtered
        assert [t.name for t in ps.get_current_tools()] == ["lc_c"]

        # Other phases → unaffected
        ps.phase = "working"
        assert [t.name for t in ps.get_current_tools()] == ["mcp_a", "lc_c"]

    def test_memo_returns_stable_identity_for_prompt_cache(self):
        """Same Python list object across turns → LLM prompt cache stays warm."""
        ps = QueenPhaseState(phase="independent")
        ps.independent_tools = [_tool("mcp_a"), _tool("lc_c")]
        ps.mcp_tool_names_all = {"mcp_a"}
        ps.enabled_mcp_tools = None
        ps.rebuild_independent_filter()

        first = ps.get_current_tools()
        second = ps.get_current_tools()
        assert first is second, "memoized list must be the same object across turns"

        # A rebuild should produce a different object so downstream caches
        # correctly invalidate.
        ps.enabled_mcp_tools = ["mcp_a"]
        ps.rebuild_independent_filter()
        third = ps.get_current_tools()
        assert third is not first
        assert [t.name for t in third] == ["mcp_a", "lc_c"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route round-trip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class _FakeSession:
    """Minimal stand-in for a live queen session used by the route tests."""

    queen_name: str
    phase_state: QueenPhaseState
    colony_runtime: Any = None
    id: str = "sess-1"
    # Set to a fake registry when a test needs the route's live-session
    # catalog enumeration (_catalog_from_live_session).
    _queen_tool_registry: Any = None
|
||||
|
||||
|
||||
@dataclass
class _FakeManager:
    """Stand-in for the session manager consulted by routes_queen_tools."""

    # Live sessions keyed by session id.
    _sessions: dict = field(default_factory=dict)
    # server-name → list of tool dicts, mimicking the boot-time MCP catalog.
    _mcp_tool_catalog: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@pytest.fixture
def queen_dir(tmp_path, monkeypatch):
    """Redirect queen profile + tools storage into a tmp dir."""
    root = tmp_path / "queens"
    root.mkdir()
    # Point both modules that resolve queen storage at the tmp dir.
    for target in (
        "framework.agents.queen.queen_profiles.QUEENS_DIR",
        "framework.agents.queen.queen_tools_config.QUEENS_DIR",
    ):
        monkeypatch.setattr(target, root)

    qid = "queen_technology"
    profile_dir = root / qid
    profile_dir.mkdir()
    profile_payload = yaml.safe_dump({"name": "Alexandra", "title": "Head of Technology"})
    (profile_dir / "profile.yaml").write_text(profile_payload)
    return root, qid
|
||||
|
||||
|
||||
async def _make_app(*, manager: _FakeManager) -> web.Application:
    """Assemble a minimal aiohttp app wired to *manager* with the queen-tools routes."""
    application = web.Application()
    application["manager"] = manager
    routes_queen_tools.register_routes(application)
    return application
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_tools_default_allows_everything_for_unknown_queen(queen_dir, monkeypatch):
    """Queens NOT in the role-default table fall back to allow-all."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)

    queens_dir, _ = queen_dir
    # Use a queen id that isn't in QUEEN_DEFAULT_CATEGORIES so we exercise
    # the fallback-to-allow-all path.
    custom_id = "queen_custom_unknown"
    (queens_dir / custom_id).mkdir()
    (queens_dir / custom_id / "profile.yaml").write_text(yaml.safe_dump({"name": "Custom", "title": "Custom Role"}))

    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "read", "input_schema": {}},
            {"name": "write_file", "description": "write", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.get(f"/api/queen/{custom_id}/tools")
        assert resp.status == 200
        body = await resp.json()

        # None allowlist == "no allowlist" → every catalog tool enabled.
        assert body["enabled_mcp_tools"] is None
        assert body["is_role_default"] is True  # no sidecar → default-allow
        assert body["stale"] is False
        servers = {s["name"]: s for s in body["mcp_servers"]}
        assert set(servers) == {"coder-tools"}
        for tool in servers["coder-tools"]["tools"]:
            assert tool["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_tools_applies_role_default(queen_dir, monkeypatch):
    """Known persona queens get their role-based default allowlist."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    _, queen_id = queen_dir  # queen_technology — has a role default

    manager = _FakeManager()
    # Seed a catalog covering tools the role default references so the
    # response reflects what the queen would actually see on boot.
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "port_scan", "description": "", "input_schema": {}},  # security
            {"name": "excel_read", "description": "", "input_schema": {}},  # data
            {"name": "fluffy_unknown_tool", "description": "", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        body = await resp.json()

        # queen_technology's role default includes file_read, data, security, etc.
        # No tools.json sidecar has been written, so is_role_default stays True.
        assert body["is_role_default"] is True
        enabled = set(body["enabled_mcp_tools"] or [])
        assert "read_file" in enabled
        assert "port_scan" in enabled  # technology role includes security
        assert "excel_read" in enabled
        # Tools not in any category (and not in a @server: expansion target
        # the role references) are NOT part of the default.
        assert "fluffy_unknown_tool" not in enabled
|
||||
|
||||
|
||||
def test_resolve_queen_default_tools_expands_server_shorthand():
    """@server:NAME shorthand expands against the provided catalog."""
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    tool_names = ["browser_navigate", "browser_click"]
    catalog = {"gcu-tools": [{"name": tool_name} for tool_name in tool_names]}
    # queen_brand_design uses "browser" category → expands via @server:gcu-tools.
    result = resolve_queen_default_tools("queen_brand_design", catalog)
    assert result is not None
    for expected in tool_names:
        assert expected in result
|
||||
|
||||
|
||||
def test_resolve_queen_default_tools_unknown_queen_returns_none():
    """A queen id with no role-default entry resolves to None (allow-all)."""
    from framework.agents.queen.queen_tools_defaults import resolve_queen_default_tools

    result = resolve_queen_default_tools("queen_made_up", {})
    assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_patch_persists_and_validates(queen_dir, monkeypatch):
    """PATCH persists the allowlist sidecar, validates names, and GET reflects it.

    Covers: happy path, null reset, and rejection of unknown tool names
    (400 with the offending names echoed, sidecar untouched).
    """
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    queens_dir, queen_id = queen_dir

    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "write_file", "description": "", "input_schema": {}},
        ]
    }

    app = await _make_app(manager=manager)
    tools_path = queens_dir / queen_id / "tools.json"
    profile_path = queens_dir / queen_id / "profile.yaml"

    async with TestClient(TestServer(app)) as client:
        # Happy path
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        body = await resp.json()
        assert body["enabled_mcp_tools"] == ["read_file"]

        # Sidecar persisted; profile YAML untouched by tools PATCH
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] == ["read_file"]
        assert "updated_at" in sidecar
        profile = yaml.safe_load(profile_path.read_text())
        assert "enabled_mcp_tools" not in profile

        # GET reflects the new state
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        body = await resp.json()
        assert body["is_role_default"] is False  # user has explicitly saved
        servers = {t["name"]: t for t in body["mcp_servers"][0]["tools"]}
        assert servers["read_file"]["enabled"] is True
        assert servers["write_file"]["enabled"] is False

        # Null resets
        resp = await client.patch(f"/api/queen/{queen_id}/tools", json={"enabled_mcp_tools": None})
        assert resp.status == 200
        body = await resp.json()
        assert body["enabled_mcp_tools"] is None
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] is None

        # Unknown tool name → 400; sidecar unchanged
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["nope_not_a_tool"]},
        )
        assert resp.status == 400
        detail = await resp.json()
        assert "nope_not_a_tool" in detail.get("unknown", [])
        sidecar = json.loads(tools_path.read_text())
        assert sidecar["enabled_mcp_tools"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_patch_hot_reloads_live_session(queen_dir, monkeypatch):
    """PATCH pushes the new allowlist into a running session's phase state."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    _, queen_id = queen_dir

    # Build a fake live session whose phase state carries a tool list the
    # filter can gate. We also need a fake registry so
    # _catalog_from_live_session can enumerate tools.
    class _FakeRegistry:
        def __init__(self, server_map, tools_by_name):
            self._mcp_server_tools = server_map
            self._tools_by_name = tools_by_name

        def get_tools(self):
            return {n: MagicMock(name=n) for n in self._tools_by_name}

    tools_by_name = {"read_file": _tool("read_file"), "write_file": _tool("write_file")}
    registry = _FakeRegistry(
        server_map={"coder-tools": {"read_file", "write_file"}},
        tools_by_name=tools_by_name,
    )
    # Patch get_tools to return real Tool objects for name/description plumbing.
    registry.get_tools = lambda: tools_by_name  # type: ignore[method-assign]

    phase_state = QueenPhaseState(phase="independent")
    phase_state.independent_tools = [tools_by_name["read_file"], tools_by_name["write_file"]]
    phase_state.mcp_tool_names_all = {"read_file", "write_file"}
    phase_state.enabled_mcp_tools = None
    phase_state.rebuild_independent_filter()

    session = _FakeSession(queen_name=queen_id, phase_state=phase_state)
    session._queen_tool_registry = registry
    manager = _FakeManager(_sessions={"sess-1": session})

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        body = await resp.json()
        assert body["refreshed_sessions"] == 1

        # Session's phase state reflects the new allowlist without a restart
        current = phase_state.get_current_tools()
        assert [t.name for t in current] == ["read_file"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_queen_returns_404(queen_dir, monkeypatch):
    """Both GET and PATCH on an unknown queen id respond 404."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    manager = _FakeManager()

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        get_resp = await client.get("/api/queen/queen_nonexistent/tools")
        assert get_resp.status == 404

        patch_resp = await client.patch(
            "/api/queen/queen_nonexistent/tools",
            json={"enabled_mcp_tools": None},
        )
        assert patch_resp.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_restores_role_default(queen_dir, monkeypatch):
    """DELETE removes tools.json so the queen falls back to the role default."""
    monkeypatch.setattr(routes_queen_tools, "ensure_default_queens", lambda: None)
    queens_dir, queen_id = queen_dir
    tools_path = queens_dir / queen_id / "tools.json"

    manager = _FakeManager()
    manager._mcp_tool_catalog = {
        "coder-tools": [
            {"name": "read_file", "description": "", "input_schema": {}},
            {"name": "port_scan", "description": "", "input_schema": {}},
        ],
    }

    app = await _make_app(manager=manager)
    async with TestClient(TestServer(app)) as client:
        # Seed a custom allowlist first so we have a sidecar to delete.
        resp = await client.patch(
            f"/api/queen/{queen_id}/tools",
            json={"enabled_mcp_tools": ["read_file"]},
        )
        assert resp.status == 200
        assert tools_path.exists()

        resp = await client.delete(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        body = await resp.json()
        assert body["removed"] is True
        assert body["is_role_default"] is True
        assert not tools_path.exists()

        # The new effective list is the role default for queen_technology,
        # which includes both read_file (file_read) and port_scan (security).
        enabled = set(body["enabled_mcp_tools"] or [])
        assert "read_file" in enabled
        assert "port_scan" in enabled

        # GET confirms.
        resp = await client.get(f"/api/queen/{queen_id}/tools")
        body = await resp.json()
        assert body["is_role_default"] is True

        # Deleting again is a no-op: 200 with removed=False (idempotent).
        resp = await client.delete(f"/api/queen/{queen_id}/tools")
        assert resp.status == 200
        assert (await resp.json())["removed"] is False
|
||||
|
||||
|
||||
def test_legacy_profile_field_migrates_to_sidecar(queen_dir):
    """A legacy enabled_mcp_tools field in profile.yaml is hoisted to tools.json."""
    queens_dir, queen_id = queen_dir
    profile_path = queens_dir / queen_id / "profile.yaml"
    tools_path = queens_dir / queen_id / "tools.json"

    # Seed legacy field in profile.yaml.
    profile = yaml.safe_load(profile_path.read_text()) or {}
    profile["enabled_mcp_tools"] = ["read_file", "write_file"]
    profile_path.write_text(yaml.safe_dump(profile, sort_keys=False))

    from framework.agents.queen.queen_tools_config import load_queen_tools_config

    # First load migrates: value is returned AND the sidecar file appears.
    assert load_queen_tools_config(queen_id) == ["read_file", "write_file"]
    assert tools_path.exists()
    sidecar = json.loads(tools_path.read_text())
    assert sidecar["enabled_mcp_tools"] == ["read_file", "write_file"]

    # profile.yaml no longer contains the field; other fields preserved.
    migrated_profile = yaml.safe_load(profile_path.read_text())
    assert "enabled_mcp_tools" not in migrated_profile
    assert migrated_profile["name"] == "Alexandra"

    # Second load is a direct read — no migration work to do.
    assert load_queen_tools_config(queen_id) == ["read_file", "write_file"]
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Shared skill authoring primitives.
|
||||
|
||||
Validates and materializes a skill folder. Used by three callers:
|
||||
|
||||
1. Queen's ``create_colony`` tool (``queen_lifecycle_tools.py``) — inline
|
||||
content passed by the queen during colony creation.
|
||||
2. HTTP POST / PUT routes under ``/api/**/skills`` — UI-driven creation.
|
||||
3. Future ``create_learned_skill`` tool — runtime learning.
|
||||
|
||||
Keeping the validators and writer here ensures the three paths share one
|
||||
authority; changes to the name regex or frontmatter layout happen in one
|
||||
place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Framework skill names include dots (``hive.note-taking``), so the
# validator needs to allow them even though the queen's ``create_colony``
# tool historically forbade dots. User-created skills without dots still
# pass; the dot cap just prevents us from rejecting existing framework
# names when the UI toggles them via ``validate_skill_name``.
_SKILL_NAME_RE = re.compile(r"^[a-z0-9.-]+$")
_MAX_NAME_LEN = 64
_MAX_DESC_LEN = 1024


@dataclass
class SkillFile:
    """Supporting file bundled with a skill (relative path + content)."""

    # Path relative to the skill folder root; vetted by ``validate_files``.
    rel_path: Path
    # Full text written verbatim to disk by ``write_skill``.
    content: str


@dataclass
class SkillDraft:
    """Validated skill content ready to be written to disk."""

    name: str
    description: str
    body: str
    files: list[SkillFile] = field(default_factory=list)

    @property
    def skill_md_text(self) -> str:
        """Assemble the final SKILL.md text (frontmatter + body)."""
        body_norm = self.body.rstrip() + "\n"
        return f"---\nname: {self.name}\ndescription: {self.description}\n---\n\n{body_norm}"


def validate_skill_name(raw: str) -> tuple[str | None, str | None]:
    """Return ``(normalized_name, error)``. Either side may be None.

    Accepts lowercase alphanumerics plus ``-`` and ``.`` (framework
    names like ``hive.note-taking``). Leading/trailing/consecutive
    separators are rejected — in particular the bare names ``.`` and
    ``..``, which the regex alone would accept and which would let
    ``remove_skill`` resolve to the scope root or its parent.
    """
    name = (raw or "").strip() if isinstance(raw, str) else ""
    if not name:
        return None, "skill_name is required"
    if not _SKILL_NAME_RE.match(name):
        # Message matches the actual pattern (dots ARE allowed).
        return None, f"skill_name '{name}' must match [a-z0-9.-] pattern"
    if name.startswith(("-", ".")) or name.endswith(("-", ".")) or "--" in name or ".." in name:
        return None, f"skill_name '{name}' has leading/trailing/consecutive separators"
    if len(name) > _MAX_NAME_LEN:
        return None, f"skill_name '{name}' exceeds {_MAX_NAME_LEN} chars"
    return name, None
|
||||
|
||||
|
||||
def validate_description(raw: str) -> tuple[str | None, str | None]:
    """Return ``(normalized_description, error)``; exactly one side is None."""
    if isinstance(raw, str):
        desc = raw.strip()
    else:
        desc = ""
    if not desc:
        return None, "skill_description is required"
    if len(desc) > _MAX_DESC_LEN:
        return None, f"skill_description must be 1–{_MAX_DESC_LEN} chars"
    # Frontmatter descriptions are line-oriented — the parser reads one value.
    if any(newline in desc for newline in ("\n", "\r")):
        return None, "skill_description must be a single line (no newlines)"
    return desc, None
|
||||
|
||||
|
||||
def validate_files(raw: list[dict] | None) -> tuple[list[SkillFile] | None, str | None]:
    """Validate the optional supporting-files payload.

    Returns ``(files, error)``. A falsy input yields ``([], None)``.
    Each entry must be a dict with a non-empty string ``'path'``
    (relative, confined to the skill folder) and a string ``'content'``;
    ``SKILL.md`` is reserved for ``skill_body``. Fails fast on the first
    offending entry with ``(None, message)``.
    """
    if not raw:
        return [], None
    if not isinstance(raw, list):
        return None, "skill_files must be an array"
    out: list[SkillFile] = []
    for entry in raw:
        if not isinstance(entry, dict):
            return None, "each skill_files entry must be an object with 'path' and 'content'"
        rel_raw = entry.get("path")
        content = entry.get("content")
        if not isinstance(rel_raw, str) or not rel_raw.strip():
            return None, "skill_files entry missing non-empty 'path'"
        if not isinstance(content, str):
            return None, f"skill_files entry '{rel_raw}' missing string 'content'"
        rel_stripped = rel_raw.strip()
        # Allow './foo' but reject '/foo' — relativizing absolute paths silently
        # has bitten other tools; make the intent loud instead.
        if rel_stripped.startswith("./"):
            rel_stripped = rel_stripped[2:]
        rel_path = Path(rel_stripped)
        # '..' anywhere in the parts would escape the skill folder.
        if rel_stripped.startswith("/") or rel_path.is_absolute() or ".." in rel_path.parts:
            return None, f"skill_files path '{rel_raw}' must be relative and inside the skill folder"
        if rel_path.as_posix() == "SKILL.md":
            return None, "skill_files must not contain SKILL.md — pass skill_body instead"
        out.append(SkillFile(rel_path=rel_path, content=content))
    return out, None
|
||||
|
||||
|
||||
def build_draft(
    *,
    skill_name: str,
    skill_description: str,
    skill_body: str,
    skill_files: list[dict] | None = None,
) -> tuple[SkillDraft | None, str | None]:
    """Validate all inputs and return a draft ready for writing.

    Delegates to ``validate_skill_name`` / ``validate_description`` /
    ``validate_files`` in that order and surfaces the first error as
    ``(None, message)``; on success returns ``(draft, None)``.
    """
    name, err = validate_skill_name(skill_name)
    if err or name is None:
        return None, err
    desc, err = validate_description(skill_description)
    if err or desc is None:
        return None, err
    # Non-string bodies (e.g. a missing JSON field) are treated as empty.
    body = skill_body if isinstance(skill_body, str) else ""
    if not body.strip():
        return None, (
            "skill_body is required — the operational procedure the colony worker needs to run this job unattended"
        )
    files, err = validate_files(skill_files)
    if err or files is None:
        return None, err
    return SkillDraft(name=name, description=desc, body=body, files=list(files)), None
|
||||
|
||||
|
||||
def write_skill(
    draft: SkillDraft,
    *,
    target_root: Path,
    replace_existing: bool = True,
) -> tuple[Path | None, str | None, bool]:
    """Write the draft under ``target_root/{draft.name}/``.

    ``target_root`` is the parent scope dir (e.g.
    ``~/.hive/agents/queens/{id}/skills`` or
    ``{colony_dir}/.hive/skills``). The function creates it if needed.

    Returns ``(installed_path, error, replaced)``. On success ``error``
    is ``None``. On failure ``installed_path`` is ``None``; if
    ``replaced`` is True the pre-existing version was already removed
    before the failure, so the target may be missing or partial —
    callers needing atomic replacement should stage elsewhere and rename.

    When ``replace_existing=False`` and the target dir already exists,
    the write is refused with a non-fatal error (caller decides whether
    to surface it as a 409 or a warning).
    """
    try:
        target_root.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        return None, f"failed to create skills root: {e}", False

    target = target_root / draft.name
    replaced = False
    try:
        if target.exists():
            if not replace_existing:
                return None, f"skill '{draft.name}' already exists", False
            # Remove the old dir outright so stale files from a prior
            # version don't linger alongside the new ones.
            shutil.rmtree(target)
            # Flag the replacement only once the old version is actually
            # gone — a failed rmtree must not report replaced=True.
            replaced = True
        target.mkdir(parents=True, exist_ok=False)
        (target / "SKILL.md").write_text(draft.skill_md_text, encoding="utf-8")
        for sf in draft.files:
            full_path = target / sf.rel_path
            full_path.parent.mkdir(parents=True, exist_ok=True)
            full_path.write_text(sf.content, encoding="utf-8")
    except OSError as e:
        return None, f"failed to write skill folder {target}: {e}", replaced
    return target, None, replaced
|
||||
|
||||
|
||||
def remove_skill(target_root: Path, skill_name: str) -> tuple[bool, str | None]:
    """Rm-tree the skill directory under ``target_root/{skill_name}/``.

    Returns ``(removed, error)``. ``removed=False, error=None`` means
    the directory didn't exist (idempotent). Name is validated on the
    way in so an attacker with UI access can't traverse out of the
    scope root.

    NOTE(review): the traversal guard is only as strict as
    ``validate_skill_name`` — its regex (``[a-z0-9.-]+``) matches the
    bare names ``.`` and ``..``; confirm the validator rejects dot-only
    names, otherwise this could rmtree the scope root or its parent.
    """
    name, err = validate_skill_name(skill_name)
    if err or name is None:
        return False, err
    target = target_root / name
    if not target.exists():
        return False, None
    try:
        shutil.rmtree(target)
    except OSError as e:
        return False, f"failed to remove skill folder {target}: {e}"
    return True, None
|
||||
@@ -7,7 +7,7 @@ locations. Resolves name collisions deterministically.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from framework.skills.parser import ParsedSkill, parse_skill_md
|
||||
@@ -30,16 +30,40 @@ _SKIP_DIRS = frozenset(
|
||||
)
|
||||
|
||||
# Scope priority (higher = takes precedence)
|
||||
# ``preset`` sits between framework and user: bundled alongside the
|
||||
# framework distribution, but off by default — capability packs the user
|
||||
# opts into per queen/colony rather than globally-enabled infra.
|
||||
_SCOPE_PRIORITY = {
|
||||
"framework": 0,
|
||||
"user": 1,
|
||||
"project": 2,
|
||||
"preset": 1,
|
||||
"user": 2,
|
||||
"queen_ui": 3,
|
||||
"colony_ui": 4,
|
||||
"project": 5,
|
||||
}
|
||||
|
||||
# Within the same scope, Hive-specific paths override cross-client paths.
|
||||
# We encode this by scanning cross-client first, then Hive-specific (later wins).
|
||||
|
||||
|
||||
@dataclass
class ExtraScope:
    """Additional scope dir to scan beyond the standard five.

    Used by :class:`framework.skills.manager.SkillsManager` to surface
    per-queen (``queen_ui``) and per-colony (``colony_ui``) skill
    directories created through the UI. The ``label`` feeds
    :attr:`ParsedSkill.source_scope` so downstream consumers (trust
    gate, UI provenance resolver) can distinguish scope origins.
    """

    # Directory to scan for skill folders (skipped if it doesn't exist).
    directory: Path
    # Logical scope name recorded on each discovered skill.
    label: str
    # Kept for forward-compat with the priority table; discovery itself
    # relies on scan order for last-wins resolution.
    priority: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryConfig:
|
||||
"""Configuration for skill discovery."""
|
||||
@@ -49,6 +73,10 @@ class DiscoveryConfig:
|
||||
skip_framework_scope: bool = False
|
||||
max_depth: int = 4
|
||||
max_dirs: int = 2000
|
||||
# Additional scope dirs scanned between user and project scopes,
|
||||
# in the order they are provided. Use ``ExtraScope`` to tag each
|
||||
# with its logical label (``queen_ui`` / ``colony_ui``).
|
||||
extra_scopes: list[ExtraScope] = field(default_factory=list)
|
||||
|
||||
|
||||
class SkillDiscovery:
|
||||
@@ -82,13 +110,22 @@ class SkillDiscovery:
|
||||
all_skills: list[ParsedSkill] = []
|
||||
self._scanned_dirs = []
|
||||
|
||||
# Framework scope (lowest precedence)
|
||||
# Framework scope (lowest precedence) — always-on infra skills.
|
||||
if not self._config.skip_framework_scope:
|
||||
framework_dir = Path(__file__).parent / "_default_skills"
|
||||
if framework_dir.is_dir():
|
||||
self._scanned_dirs.append(framework_dir)
|
||||
all_skills.extend(self._scan_scope(framework_dir, "framework"))
|
||||
|
||||
# Preset scope — bundled capability packs that ship with the
|
||||
# framework but default to OFF. User opts in per queen/colony
|
||||
# via the Skills Library. ``skip_framework_scope`` covers both
|
||||
# bundled directories since they live side-by-side on disk.
|
||||
preset_dir = Path(__file__).parent / "_preset_skills"
|
||||
if preset_dir.is_dir():
|
||||
self._scanned_dirs.append(preset_dir)
|
||||
all_skills.extend(self._scan_scope(preset_dir, "preset"))
|
||||
|
||||
# User scope
|
||||
if not self._config.skip_user_scope:
|
||||
home = Path.home()
|
||||
@@ -105,6 +142,13 @@ class SkillDiscovery:
|
||||
self._scanned_dirs.append(user_hive)
|
||||
all_skills.extend(self._scan_scope(user_hive, "user"))
|
||||
|
||||
# Extra scopes (queen_ui / colony_ui), scanned between user and project
|
||||
# so colony overrides beat queen overrides, and both beat user-scope.
|
||||
for extra in self._config.extra_scopes:
|
||||
if extra.directory.is_dir():
|
||||
self._scanned_dirs.append(extra.directory)
|
||||
all_skills.extend(self._scan_scope(extra.directory, extra.label))
|
||||
|
||||
# Project scope (highest precedence)
|
||||
if self._config.project_root:
|
||||
root = self._config.project_root
|
||||
|
||||
@@ -23,6 +23,7 @@ Typical usage — **bare** (exported agents, SDK users)::
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -44,6 +45,18 @@ class SkillsManagerConfig:
|
||||
even when ``project_root`` is set.
|
||||
interactive: Whether trust gating can prompt the user interactively.
|
||||
When ``False``, untrusted project skills are silently skipped.
|
||||
queen_id: Optional queen identifier. When set, enables the
|
||||
``queen_ui`` scope and per-queen override file.
|
||||
queen_overrides_path: Path to
|
||||
``~/.hive/agents/queens/{queen_id}/skills_overrides.json``.
|
||||
When set, the store is loaded and its entries override
|
||||
discovery results (disable skills, record provenance).
|
||||
colony_name: Optional colony identifier; mirrors ``queen_id`` for
|
||||
the ``colony_ui`` scope.
|
||||
colony_overrides_path: Per-colony override file path.
|
||||
extra_scope_dirs: Extra scope dirs scanned between user and
|
||||
project scopes. Typically populated by the caller with the
|
||||
queen/colony UI skill directories.
|
||||
"""
|
||||
|
||||
skills_config: SkillsConfig = field(default_factory=SkillsConfig)
|
||||
@@ -51,6 +64,15 @@ class SkillsManagerConfig:
|
||||
skip_community_discovery: bool = False
|
||||
interactive: bool = True
|
||||
|
||||
# Override support
|
||||
queen_id: str | None = None
|
||||
queen_overrides_path: Path | None = None
|
||||
colony_name: str | None = None
|
||||
colony_overrides_path: Path | None = None
|
||||
# Typed at the call site as ``list[ExtraScope]`` — not imported here
|
||||
# to keep this module free of discovery-layer dependencies.
|
||||
extra_scope_dirs: list = field(default_factory=list)
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Unified skill lifecycle: discovery → loading → prompt renderation.
|
||||
@@ -65,13 +87,21 @@ class SkillsManager:
|
||||
self._config = config or SkillsManagerConfig()
|
||||
self._loaded = False
|
||||
self._catalog: object = None # SkillCatalog, set after load()
|
||||
self._all_skills: list = [] # list[ParsedSkill], pre-override-filter
|
||||
self._catalog_prompt: str = ""
|
||||
self._protocols_prompt: str = ""
|
||||
self._allowlisted_dirs: list[str] = []
|
||||
self._default_mgr: object = None # DefaultSkillManager, set after load()
|
||||
# Override stores (loaded lazily in _do_load). Queen-scope and
|
||||
# colony-scope are read together; colony entries win on collision.
|
||||
self._queen_overrides: object = None # SkillOverrideStore | None
|
||||
self._colony_overrides: object = None # SkillOverrideStore | None
|
||||
# Hot-reload state
|
||||
self._watched_dirs: list[str] = []
|
||||
self._watched_files: list[str] = []
|
||||
self._watcher_task: object = None # asyncio.Task, set by start_watching()
|
||||
# Serializes in-process mutations (HTTP handlers + create_colony).
|
||||
self._mutation_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory for backwards-compat bridge
|
||||
@@ -119,6 +149,7 @@ class SkillsManager:
|
||||
from framework.skills.catalog import SkillCatalog
|
||||
from framework.skills.defaults import DefaultSkillManager
|
||||
from framework.skills.discovery import DiscoveryConfig, SkillDiscovery
|
||||
from framework.skills.overrides import SkillOverrideStore
|
||||
|
||||
skills_config = self._config.skills_config
|
||||
|
||||
@@ -128,12 +159,13 @@ class SkillsManager:
|
||||
DiscoveryConfig(
|
||||
project_root=self._config.project_root,
|
||||
skip_framework_scope=False,
|
||||
extra_scopes=list(self._config.extra_scope_dirs or []),
|
||||
)
|
||||
)
|
||||
discovered = discovery.discover()
|
||||
self._watched_dirs = discovery.scanned_directories
|
||||
|
||||
# Trust-gate project-scope skills (AS-13)
|
||||
# Trust-gate project-scope skills (AS-13). UI scopes bypass.
|
||||
if self._config.project_root is not None and not self._config.skip_community_discovery:
|
||||
from framework.skills.trust import TrustGate
|
||||
|
||||
@@ -141,6 +173,31 @@ class SkillsManager:
|
||||
discovered, project_dir=self._config.project_root
|
||||
)
|
||||
|
||||
# 1b. Load per-scope override stores. Missing files → empty stores.
|
||||
queen_store = None
|
||||
if self._config.queen_overrides_path is not None:
|
||||
queen_store = SkillOverrideStore.load(
|
||||
self._config.queen_overrides_path,
|
||||
scope_label=f"queen:{self._config.queen_id or ''}",
|
||||
)
|
||||
colony_store = None
|
||||
if self._config.colony_overrides_path is not None:
|
||||
colony_store = SkillOverrideStore.load(
|
||||
self._config.colony_overrides_path,
|
||||
scope_label=f"colony:{self._config.colony_name or ''}",
|
||||
)
|
||||
self._queen_overrides = queen_store
|
||||
self._colony_overrides = colony_store
|
||||
self._watched_files = [
|
||||
str(p) for p in (self._config.queen_overrides_path, self._config.colony_overrides_path) if p is not None
|
||||
]
|
||||
|
||||
# 1c. Apply override filtering. Colony entries take precedence over
|
||||
# queen entries on name collision; the store's ``is_disabled`` keeps
|
||||
# the resolution rule in one place.
|
||||
self._all_skills = list(discovered)
|
||||
discovered = self._apply_overrides(discovered, skills_config, queen_store, colony_store)
|
||||
|
||||
catalog = SkillCatalog(discovered)
|
||||
self._catalog = catalog
|
||||
self._allowlisted_dirs = catalog.allowlisted_dirs
|
||||
@@ -174,6 +231,101 @@ class SkillsManager:
|
||||
len(catalog_prompt),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override application
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _apply_overrides(
|
||||
discovered: list,
|
||||
skills_config: SkillsConfig,
|
||||
queen_store: object,
|
||||
colony_store: object,
|
||||
) -> list:
|
||||
"""Filter ``discovered`` per the queen + colony override stores.
|
||||
|
||||
Resolution rule:
|
||||
1. Tombstoned names (``deleted_ui_skills``) drop out.
|
||||
2. An explicit ``enabled=False`` override drops the skill.
|
||||
3. An explicit ``enabled=True`` override keeps it (wins over
|
||||
``all_defaults_disabled`` for framework defaults AND over the
|
||||
preset-scope default-off rule).
|
||||
4. Otherwise: preset-scope skills are off by default; everything
|
||||
else inherits :meth:`SkillsConfig.is_default_enabled`.
|
||||
"""
|
||||
from framework.skills.overrides import SkillOverrideStore
|
||||
|
||||
stores: list[SkillOverrideStore] = [s for s in (queen_store, colony_store) if s is not None]
|
||||
|
||||
tombstones: set[str] = set()
|
||||
for store in stores:
|
||||
tombstones |= set(store.deleted_ui_skills)
|
||||
|
||||
out = []
|
||||
for skill in discovered:
|
||||
if skill.name in tombstones:
|
||||
continue
|
||||
# Check colony first so colony overrides win over queen's.
|
||||
explicit: bool | None = None
|
||||
master_disabled = False
|
||||
for store in reversed(stores): # colony, then queen
|
||||
entry = store.get(skill.name)
|
||||
if entry is not None and entry.enabled is not None:
|
||||
explicit = entry.enabled
|
||||
break
|
||||
if store.all_defaults_disabled:
|
||||
master_disabled = True
|
||||
if explicit is False:
|
||||
continue
|
||||
if explicit is True:
|
||||
out.append(skill)
|
||||
continue
|
||||
# Preset-scope capability packs are bundled but ship OFF; the
|
||||
# user must explicitly enable them per queen or colony. This
|
||||
# runs even when no store is present so bare agents don't
|
||||
# silently load x-automation etc.
|
||||
if skill.source_scope == "preset":
|
||||
continue
|
||||
# No explicit entry — master switch takes effect against framework defaults.
|
||||
default_enabled = skills_config.is_default_enabled(skill.name)
|
||||
if master_disabled and default_enabled and skill.source_scope == "framework":
|
||||
continue
|
||||
if default_enabled:
|
||||
out.append(skill)
|
||||
return out
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def queen_overrides(self) -> object:
|
||||
"""The queen-scope :class:`SkillOverrideStore` or ``None``."""
|
||||
return self._queen_overrides
|
||||
|
||||
@property
|
||||
def colony_overrides(self) -> object:
|
||||
"""The colony-scope :class:`SkillOverrideStore` or ``None``."""
|
||||
return self._colony_overrides
|
||||
|
||||
@property
|
||||
def mutation_lock(self) -> asyncio.Lock:
|
||||
"""Serializes in-process override mutations (routes + queen tools)."""
|
||||
return self._mutation_lock
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Re-run discovery and rebuild cached prompts. Public wrapper for ``_reload``."""
|
||||
self._reload()
|
||||
|
||||
def enumerate_skills_with_source(self) -> list:
|
||||
"""Return every discovered skill, including ones disabled by overrides.
|
||||
|
||||
The UI relies on this: a disabled framework skill needs to render
|
||||
in the list so the user can toggle it back on. The post-filter
|
||||
catalog omits those entries.
|
||||
"""
|
||||
return list(self._all_skills)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hot-reload: watch skill directories for SKILL.md changes.
|
||||
# ------------------------------------------------------------------
|
||||
@@ -181,14 +333,14 @@ class SkillsManager:
|
||||
async def start_watching(self) -> None:
|
||||
"""Start a background task watching skill directories for changes.
|
||||
|
||||
When a ``SKILL.md`` file is added/modified/removed, the cached
|
||||
``skills_catalog_prompt`` is rebuilt. The next node iteration picks
|
||||
up the new prompt automatically via the ``dynamic_prompt_provider``.
|
||||
Triggers a reload when any ``SKILL.md`` changes or an override
|
||||
JSON file is modified. The next node iteration picks up the new
|
||||
prompt via the ``dynamic_prompt_provider`` / per-worker
|
||||
``dynamic_skills_catalog_provider``.
|
||||
|
||||
Silently no-ops when ``watchfiles`` is not installed or when no
|
||||
directories are being watched (e.g. bare mode, no project_root).
|
||||
Silently no-ops when ``watchfiles`` is not installed or there
|
||||
are no paths to watch.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
import watchfiles # noqa: F401 -- optional dep check
|
||||
@@ -196,7 +348,7 @@ class SkillsManager:
|
||||
logger.debug("watchfiles not installed; skill hot-reload disabled")
|
||||
return
|
||||
|
||||
if not self._watched_dirs:
|
||||
if not self._watched_dirs and not self._watched_files:
|
||||
logger.debug("No skill directories to watch; hot-reload skipped")
|
||||
return
|
||||
|
||||
@@ -208,14 +360,13 @@ class SkillsManager:
|
||||
name="skills-hot-reload",
|
||||
)
|
||||
logger.info(
|
||||
"Skill hot-reload enabled (watching %d directories)",
|
||||
"Skill hot-reload enabled (watching %d dirs, %d override files)",
|
||||
len(self._watched_dirs),
|
||||
len(self._watched_files),
|
||||
)
|
||||
|
||||
async def stop_watching(self) -> None:
|
||||
"""Cancel the background watcher task (if running)."""
|
||||
import asyncio
|
||||
|
||||
task = self._watcher_task
|
||||
if task is None:
|
||||
return
|
||||
@@ -228,22 +379,35 @@ class SkillsManager:
|
||||
pass
|
||||
|
||||
async def _watch_loop(self) -> None:
|
||||
"""Background coroutine that watches SKILL.md files and triggers reload."""
|
||||
import asyncio
|
||||
|
||||
"""Watch SKILL.md + override JSON files and trigger reload on change."""
|
||||
import watchfiles
|
||||
|
||||
def _filter(_change: object, path: str) -> bool:
|
||||
return path.endswith("SKILL.md")
|
||||
return path.endswith("SKILL.md") or path.endswith("skills_overrides.json")
|
||||
|
||||
# watchfiles accepts a mix of dirs and files; file watches survive
|
||||
# a tmp+rename (the containing dir sees the event).
|
||||
watch_targets = list(self._watched_dirs)
|
||||
for f in self._watched_files:
|
||||
# watchfiles needs the parent dir for file-level events to fire
|
||||
# reliably through atomic replace; adding the file path directly
|
||||
# works on Linux/macOS inotify/FSEvents but a dir watch is
|
||||
# belt-and-braces.
|
||||
parent = str(Path(f).parent)
|
||||
if parent not in watch_targets:
|
||||
watch_targets.append(parent)
|
||||
|
||||
if not watch_targets:
|
||||
return
|
||||
|
||||
try:
|
||||
async for changes in watchfiles.awatch(
|
||||
*self._watched_dirs,
|
||||
*watch_targets,
|
||||
watch_filter=_filter,
|
||||
debounce=1000,
|
||||
):
|
||||
paths = [p for _, p in changes]
|
||||
logger.info("SKILL.md changes detected: %s", paths)
|
||||
logger.info("Skill state changes detected: %s", paths)
|
||||
try:
|
||||
self._reload()
|
||||
except Exception:
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Per-scope skill override store.
|
||||
|
||||
Sits between :mod:`framework.skills.discovery` and
|
||||
:class:`framework.skills.catalog.SkillCatalog`: records the user's
|
||||
per-queen and per-colony decisions about which skills are enabled,
|
||||
who created them (provenance), and any parameter tweaks.
|
||||
|
||||
Two well-known paths back this module:
|
||||
|
||||
* Queen scope: ``~/.hive/agents/queens/{queen_id}/skills_overrides.json``
|
||||
* Colony scope: ``~/.hive/colonies/{colony_name}/skills_overrides.json``
|
||||
|
||||
The schema is intentionally small; see :class:`SkillOverrideStore` for
|
||||
the JSON shape. Atomic writes mirror
|
||||
:class:`framework.skills.trust.TrustedRepoStore` (tmp + rename).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
class Provenance(StrEnum):
|
||||
"""Where a skill came from.
|
||||
|
||||
The override store is the authoritative provenance ledger for anything
|
||||
the UI or the queen tools touched. Framework / user-dropped /
|
||||
project-dropped skills don't need an entry unless they've been
|
||||
explicitly configured.
|
||||
"""
|
||||
|
||||
FRAMEWORK = "framework"
|
||||
PRESET = "preset"
|
||||
USER_DROPPED = "user_dropped"
|
||||
USER_UI_CREATED = "user_ui_created"
|
||||
QUEEN_CREATED = "queen_created"
|
||||
LEARNED_RUNTIME = "learned_runtime"
|
||||
PROJECT_DROPPED = "project_dropped"
|
||||
# Catch-all for skills with no recorded authorship: legacy rows from
|
||||
# before the override store existed, PATCHes that precede any CREATE,
|
||||
# etc. Keeps the ledger honest rather than forcing a guess.
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OverrideEntry:
|
||||
"""Per-skill override record inside a scope's store."""
|
||||
|
||||
enabled: bool | None = None
|
||||
provenance: Provenance = Provenance.FRAMEWORK
|
||||
trust: str | None = None
|
||||
param_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
notes: str | None = None
|
||||
created_at: datetime | None = None
|
||||
created_by: str | None = None
|
||||
|
||||
def clone(self) -> OverrideEntry:
|
||||
"""Return a deep-enough copy (dict fields are re-allocated)."""
|
||||
return OverrideEntry(
|
||||
enabled=self.enabled,
|
||||
provenance=self.provenance,
|
||||
trust=self.trust,
|
||||
param_overrides=dict(self.param_overrides),
|
||||
notes=self.notes,
|
||||
created_at=self.created_at,
|
||||
created_by=self.created_by,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {"provenance": str(self.provenance)}
|
||||
if self.enabled is not None:
|
||||
out["enabled"] = bool(self.enabled)
|
||||
if self.trust is not None:
|
||||
out["trust"] = self.trust
|
||||
if self.param_overrides:
|
||||
out["param_overrides"] = dict(self.param_overrides)
|
||||
if self.notes is not None:
|
||||
out["notes"] = self.notes
|
||||
if self.created_at is not None:
|
||||
out["created_at"] = self.created_at.isoformat()
|
||||
if self.created_by is not None:
|
||||
out["created_by"] = self.created_by
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, raw: dict[str, Any]) -> OverrideEntry:
|
||||
created_at_raw = raw.get("created_at")
|
||||
created_at: datetime | None = None
|
||||
if isinstance(created_at_raw, str):
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_raw)
|
||||
except ValueError:
|
||||
created_at = None
|
||||
provenance_raw = raw.get("provenance") or Provenance.FRAMEWORK
|
||||
try:
|
||||
provenance = Provenance(provenance_raw)
|
||||
except ValueError:
|
||||
logger.warning("override: unknown provenance %r; defaulting to framework", provenance_raw)
|
||||
provenance = Provenance.FRAMEWORK
|
||||
enabled = raw.get("enabled")
|
||||
return cls(
|
||||
enabled=enabled if isinstance(enabled, bool) else None,
|
||||
provenance=provenance,
|
||||
trust=raw.get("trust") if isinstance(raw.get("trust"), str) else None,
|
||||
param_overrides=dict(raw.get("param_overrides") or {}),
|
||||
notes=raw.get("notes") if isinstance(raw.get("notes"), str) else None,
|
||||
created_at=created_at,
|
||||
created_by=raw.get("created_by") if isinstance(raw.get("created_by"), str) else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillOverrideStore:
|
||||
"""Persistent per-scope override file.
|
||||
|
||||
The file is created lazily on first save; a missing file behaves like
|
||||
an empty store (all skills inherit defaults, no metadata recorded).
|
||||
"""
|
||||
|
||||
path: Path
|
||||
scope_label: str = ""
|
||||
version: int = _SCHEMA_VERSION
|
||||
all_defaults_disabled: bool = False
|
||||
overrides: dict[str, OverrideEntry] = field(default_factory=dict)
|
||||
deleted_ui_skills: set[str] = field(default_factory=set)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path, scope_label: str = "") -> SkillOverrideStore:
|
||||
"""Load the store from disk; return an empty store if the file is absent.
|
||||
|
||||
Permissive on parse errors: logs and returns an empty store rather
|
||||
than raising, so a corrupted file never takes down skill loading.
|
||||
"""
|
||||
store = cls(path=path, scope_label=scope_label)
|
||||
try:
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
except FileNotFoundError:
|
||||
return store
|
||||
except Exception as exc:
|
||||
logger.warning("override: failed to read %s (%s); starting empty", path, exc)
|
||||
return store
|
||||
if not isinstance(raw, dict):
|
||||
logger.warning("override: %s is not an object; starting empty", path)
|
||||
return store
|
||||
|
||||
store.version = int(raw.get("version", _SCHEMA_VERSION))
|
||||
store.all_defaults_disabled = bool(raw.get("all_defaults_disabled", False))
|
||||
raw_overrides = raw.get("overrides") or {}
|
||||
if isinstance(raw_overrides, dict):
|
||||
for name, entry_raw in raw_overrides.items():
|
||||
if not isinstance(name, str) or not isinstance(entry_raw, dict):
|
||||
continue
|
||||
store.overrides[name] = OverrideEntry.from_dict(entry_raw)
|
||||
deleted = raw.get("deleted_ui_skills") or []
|
||||
if isinstance(deleted, list):
|
||||
store.deleted_ui_skills = {s for s in deleted if isinstance(s, str)}
|
||||
return store
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mutations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def upsert(self, skill_name: str, entry: OverrideEntry) -> None:
|
||||
"""Insert or replace a skill's override entry."""
|
||||
self.overrides[skill_name] = entry
|
||||
# If we're explicitly managing this skill again, lift any tombstone.
|
||||
self.deleted_ui_skills.discard(skill_name)
|
||||
|
||||
def set_enabled(self, skill_name: str, enabled: bool, *, provenance: Provenance | None = None) -> None:
|
||||
"""Convenience: toggle enabled without rewriting other fields."""
|
||||
existing = self.overrides.get(skill_name)
|
||||
if existing is None:
|
||||
existing = OverrideEntry(
|
||||
enabled=enabled,
|
||||
provenance=provenance or Provenance.FRAMEWORK,
|
||||
)
|
||||
else:
|
||||
existing.enabled = enabled
|
||||
if provenance is not None:
|
||||
existing.provenance = provenance
|
||||
self.overrides[skill_name] = existing
|
||||
|
||||
def remove(self, skill_name: str, *, tombstone: bool = True) -> None:
|
||||
"""Drop a skill's override entry; optionally leave a tombstone.
|
||||
|
||||
Tombstones matter for UI-created skills: if the user deletes a
|
||||
queen-scope skill via the UI, we rm-tree its directory, but the
|
||||
file watcher might lag or a background process might have an
|
||||
open handle. A tombstone ensures the loader treats the skill as
|
||||
gone even if a stale SKILL.md lingers.
|
||||
"""
|
||||
self.overrides.pop(skill_name, None)
|
||||
if tombstone:
|
||||
self.deleted_ui_skills.add(skill_name)
|
||||
|
||||
def is_disabled(self, skill_name: str, *, default_enabled: bool) -> bool:
|
||||
"""Return True when this scope's override force-disables the skill."""
|
||||
if self.all_defaults_disabled and default_enabled:
|
||||
# Caller says "default enabled"; master switch flips it off unless
|
||||
# an explicit enabled=True override re-enables.
|
||||
entry = self.overrides.get(skill_name)
|
||||
if entry is not None and entry.enabled is True:
|
||||
return False
|
||||
return True
|
||||
entry = self.overrides.get(skill_name)
|
||||
if entry is None:
|
||||
return not default_enabled
|
||||
if entry.enabled is None:
|
||||
return not default_enabled
|
||||
return not entry.enabled
|
||||
|
||||
def effective_enabled(self, skill_name: str, *, default_enabled: bool) -> bool:
|
||||
"""The inverse of :meth:`is_disabled`, for readability at call sites."""
|
||||
return not self.is_disabled(skill_name, default_enabled=default_enabled)
|
||||
|
||||
def get(self, skill_name: str) -> OverrideEntry | None:
|
||||
return self.overrides.get(skill_name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save(self) -> None:
|
||||
"""Atomic write: tmp + rename. Creates the parent dir if needed."""
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload: dict[str, Any] = {
|
||||
"version": self.version,
|
||||
"all_defaults_disabled": self.all_defaults_disabled,
|
||||
"overrides": {name: entry.to_dict() for name, entry in sorted(self.overrides.items())},
|
||||
}
|
||||
if self.deleted_ui_skills:
|
||||
payload["deleted_ui_skills"] = sorted(self.deleted_ui_skills)
|
||||
tmp = self.path.with_suffix(self.path.suffix + ".tmp")
|
||||
tmp.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
tmp.replace(self.path)
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Single source of truth for override timestamps."""
|
||||
return datetime.now(tz=UTC)
|
||||
@@ -20,9 +20,17 @@ from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_SKILLS_DIR = Path(__file__).parent / "_default_skills"
|
||||
# Bundled skills live in two sibling dirs: ``_default_skills`` (always-on
|
||||
# infra) and ``_preset_skills`` (capability packs, off by default but
|
||||
# still bundled). Tool-gated pre-activation walks both so ``browser_*``
|
||||
# tools still pull in the browser-automation preset even though it isn't
|
||||
# default-enabled in the catalog.
|
||||
_BUNDLED_DIRS: tuple[Path, ...] = (
|
||||
Path(__file__).parent / "_default_skills",
|
||||
Path(__file__).parent / "_preset_skills",
|
||||
)
|
||||
|
||||
# (tool-name prefix, default skill directory name, display name)
|
||||
# (tool-name prefix, skill directory name, display name)
|
||||
_TOOL_GATED_SKILLS: list[tuple[str, str, str]] = [
|
||||
("browser_", "browser-automation", "hive.browser-automation"),
|
||||
]
|
||||
@@ -31,12 +39,23 @@ _BODY_CACHE: dict[str, str] = {}
|
||||
|
||||
|
||||
def _load_body(dir_name: str) -> str:
|
||||
"""Load the markdown body of a framework default skill, cached."""
|
||||
"""Load the markdown body of a bundled skill, cached. Searches every
|
||||
bundled directory (default + preset) so the mapping table doesn't
|
||||
need to know which dir a skill lives in.
|
||||
"""
|
||||
if dir_name in _BODY_CACHE:
|
||||
return _BODY_CACHE[dir_name]
|
||||
|
||||
path = _DEFAULT_SKILLS_DIR / dir_name / "SKILL.md"
|
||||
path: Path | None = None
|
||||
for parent in _BUNDLED_DIRS:
|
||||
candidate = parent / dir_name / "SKILL.md"
|
||||
if candidate.exists():
|
||||
path = candidate
|
||||
break
|
||||
body = ""
|
||||
if path is None:
|
||||
_BODY_CACHE[dir_name] = body
|
||||
return body
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
# Strip YAML frontmatter (between the first two '---' fences)
|
||||
|
||||
@@ -318,13 +318,19 @@ class TrustGate:
|
||||
) -> list[ParsedSkill]:
|
||||
"""Return the subset of skills that are trusted for loading.
|
||||
|
||||
- Framework and user-scope skills: always included.
|
||||
- Framework, user, queen_ui, and colony_ui scopes: always included.
|
||||
(UI-created skills are authenticated by the user creating them
|
||||
through the authenticated UI — they do not go through the
|
||||
trusted_repos.json flow.)
|
||||
- Project-scope skills: classified; consent prompt shown if untrusted.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Separate project skills from always-trusted scopes
|
||||
always_trusted = [s for s in skills if s.source_scope != "project"]
|
||||
# UI-authored scopes bypass the trust gate — they're implicitly
|
||||
# trusted because the user authored them through the UI. ``preset``
|
||||
# ships with the framework distribution, so it's trusted too.
|
||||
_bypass_scopes = {"framework", "preset", "user", "queen_ui", "colony_ui"}
|
||||
always_trusted = [s for s in skills if s.source_scope in _bypass_scopes]
|
||||
project_skills = [s for s in skills if s.source_scope == "project"]
|
||||
|
||||
if not project_skills:
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""File-backed, lock-coordinated task tracker for the hive agent loop.
|
||||
|
||||
See temp/tasks-system-implementation-plan.md for the design. Two list types:
|
||||
|
||||
colony:{colony_id} -- the queen's spawn-plan template
|
||||
session:{agent_id}:{sess_id} -- per-session working list
|
||||
|
||||
Each agent operates on its own session list via the session task tools
|
||||
(`task_create_batch`, `task_create`, `task_update`, `task_list`,
|
||||
`task_get`). The colony
|
||||
template is addressed only by the queen's `colony_template_*` tools and by
|
||||
the UI/event surface.
|
||||
"""
|
||||
|
||||
from framework.tasks.models import (
|
||||
ClaimResult,
|
||||
TaskListMeta,
|
||||
TaskListRole,
|
||||
TaskRecord,
|
||||
TaskStatus,
|
||||
)
|
||||
from framework.tasks.scoping import (
|
||||
colony_task_list_id,
|
||||
parse_task_list_id,
|
||||
resolve_task_list_id,
|
||||
session_task_list_id,
|
||||
)
|
||||
from framework.tasks.store import (
|
||||
TaskStore,
|
||||
get_task_store,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClaimResult",
|
||||
"TaskListMeta",
|
||||
"TaskListRole",
|
||||
"TaskRecord",
|
||||
"TaskStatus",
|
||||
"TaskStore",
|
||||
"colony_task_list_id",
|
||||
"get_task_store",
|
||||
"parse_task_list_id",
|
||||
"resolve_task_list_id",
|
||||
"session_task_list_id",
|
||||
]
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Bridge from the task store to the EventBus.
|
||||
|
||||
The store is intentionally event-free — it's pure storage. The tool
|
||||
executors (and run_parallel_workers, and any future colony_template_*
|
||||
caller) are responsible for emitting the lifecycle events to the bus
|
||||
after successful mutations.
|
||||
|
||||
Events are scoped to a stream_id pulled from the execution context if
|
||||
available; otherwise they fan out at the global ``primary`` stream so the
|
||||
UI's broad subscriptions still see them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from framework.host.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.tasks.models import TaskRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Process-global default — set by the runner / orchestrator at bringup.
|
||||
_DEFAULT_BUS: EventBus | None = None
|
||||
|
||||
|
||||
def set_default_event_bus(bus: EventBus | None) -> None:
|
||||
global _DEFAULT_BUS
|
||||
_DEFAULT_BUS = bus
|
||||
|
||||
|
||||
def _get_bus(bus: EventBus | None = None) -> EventBus | None:
|
||||
return bus or _DEFAULT_BUS
|
||||
|
||||
|
||||
def _serialize_record(rec: TaskRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": rec.id,
|
||||
"subject": rec.subject,
|
||||
"description": rec.description,
|
||||
"active_form": rec.active_form,
|
||||
"owner": rec.owner,
|
||||
"status": rec.status.value,
|
||||
"blocks": list(rec.blocks),
|
||||
"blocked_by": list(rec.blocked_by),
|
||||
"metadata": dict(rec.metadata),
|
||||
"created_at": rec.created_at,
|
||||
"updated_at": rec.updated_at,
|
||||
}
|
||||
|
||||
|
||||
async def emit_task_created(
|
||||
*,
|
||||
task_list_id: str,
|
||||
record: TaskRecord,
|
||||
stream_id: str = "primary",
|
||||
bus: EventBus | None = None,
|
||||
) -> None:
|
||||
b = _get_bus(bus)
|
||||
if b is None:
|
||||
return
|
||||
try:
|
||||
await b.publish(
|
||||
AgentEvent(
|
||||
type=EventType.TASK_CREATED,
|
||||
stream_id=stream_id,
|
||||
data={
|
||||
"task_list_id": task_list_id,
|
||||
"task": _serialize_record(record),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("emit_task_created failed", exc_info=True)
|
||||
|
||||
|
||||
async def emit_task_updated(
    *,
    task_list_id: str,
    record: TaskRecord,
    fields: list[str],
    stream_id: str = "primary",
    bus: EventBus | None = None,
) -> None:
    """Publish a TASK_UPDATED event; no-op when no bus or no changed fields."""
    target = _get_bus(bus)
    if target is None or not fields:
        return
    try:
        event = AgentEvent(
            type=EventType.TASK_UPDATED,
            stream_id=stream_id,
            data={
                "task_list_id": task_list_id,
                "task_id": record.id,
                "after": _serialize_record(record),
                "fields": fields,
            },
        )
        await target.publish(event)
    except Exception:
        # Event delivery must never break the mutation that triggered it.
        logger.debug("emit_task_updated failed", exc_info=True)
|
||||
|
||||
|
||||
async def emit_task_deleted(
    *,
    task_list_id: str,
    task_id: int,
    cascade: list[int],
    stream_id: str = "primary",
    bus: EventBus | None = None,
) -> None:
    """Publish a TASK_DELETED event (with cascaded ids). Best-effort."""
    target = _get_bus(bus)
    if target is None:
        return
    payload = {
        "task_list_id": task_list_id,
        "task_id": task_id,
        "cascade": cascade,
    }
    try:
        await target.publish(
            AgentEvent(type=EventType.TASK_DELETED, stream_id=stream_id, data=payload)
        )
    except Exception:
        # Event delivery must never break the mutation that triggered it.
        logger.debug("emit_task_deleted failed", exc_info=True)
|
||||
|
||||
|
||||
async def emit_colony_template_assignment(
    *,
    colony_id: str,
    task_id: int,
    assigned_session: str | None,
    assigned_worker_id: str | None,
    stream_id: str = "primary",
    bus: EventBus | None = None,
) -> None:
    """Publish a COLONY_TEMPLATE_ASSIGNMENT event. Best-effort: never raises."""
    target = _get_bus(bus)
    if target is None:
        return
    payload = {
        "colony_id": colony_id,
        "task_id": task_id,
        "assigned_session": assigned_session,
        "assigned_worker_id": assigned_worker_id,
    }
    try:
        await target.publish(
            AgentEvent(
                type=EventType.COLONY_TEMPLATE_ASSIGNMENT,
                stream_id=stream_id,
                data=payload,
            )
        )
    except Exception:
        # Event delivery must never break the mutation that triggered it.
        logger.debug("emit_colony_template_assignment failed", exc_info=True)
|
||||
@@ -0,0 +1,103 @@
|
||||
"""Task lifecycle hooks.
|
||||
|
||||
Two events:
|
||||
|
||||
* ``task_created`` -- fires after the task file is written but before the
|
||||
tool returns. Hooks may raise ``BlockingHookError``
|
||||
to abort creation; the wrapper deletes the just-
|
||||
created task and returns an error tool_result.
|
||||
|
||||
* ``task_completed`` -- fires when ``task_update`` transitions a task to
|
||||
``completed``. A blocking error rolls the status
|
||||
back to ``in_progress`` and surfaces the error.
|
||||
|
||||
Hooks are registered on a process-global registry so callers (test
|
||||
fixtures, integrations) can install them without threading through the
|
||||
agent loop. They run in registration order; any hook may abort by raising
|
||||
``BlockingHookError``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Event names accepted by register_hook / clear_hooks / run_task_hooks.
HOOK_TASK_CREATED = "task_created"
HOOK_TASK_COMPLETED = "task_completed"
|
||||
|
||||
|
||||
class BlockingHookError(Exception):
    """Raised by a hook to veto the surrounding tool operation.

    Re-raised out of :func:`run_task_hooks`; the calling tool wrapper is
    responsible for rolling back the mutation it guards (see module docstring).
    """
|
||||
|
||||
|
||||
@dataclass
class TaskHookContext:
    """Payload handed to every task hook invocation."""

    event: str  # HOOK_TASK_CREATED or HOOK_TASK_COMPLETED
    task_list_id: str
    task: Any  # TaskRecord (avoid import cycle)
    agent_id: str | None = None
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# A hook may be sync or async; awaitable results are awaited by run_task_hooks.
HookFn = Callable[[TaskHookContext], Any | Awaitable[Any]]

# Process-global registry — hooks run in registration order per event.
_HOOK_REGISTRY: dict[str, list[HookFn]] = {
    HOOK_TASK_CREATED: [],
    HOOK_TASK_COMPLETED: [],
}
|
||||
|
||||
|
||||
def register_hook(event: str, fn: HookFn) -> None:
    """Append *fn* to the chain for *event*; hooks run in registration order."""
    try:
        bucket = _HOOK_REGISTRY[event]
    except KeyError:
        raise ValueError(f"Unknown hook event: {event!r}") from None
    bucket.append(fn)
|
||||
|
||||
|
||||
def clear_hooks(event: str | None = None) -> None:
    """Test helper. Clear all hooks (or just one event's)."""
    if event is None:
        buckets = list(_HOOK_REGISTRY.values())
    else:
        buckets = [_HOOK_REGISTRY.get(event, [])]
    for bucket in buckets:
        bucket.clear()
|
||||
|
||||
|
||||
async def run_task_hooks(
    event: str,
    *,
    task_list_id: str,
    task: Any,
    agent_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Run all hooks registered for ``event``.

    Re-raises ``BlockingHookError`` from any hook; the caller is responsible
    for rolling back the operation. Any other exception is logged and
    swallowed so a buggy observer cannot abort the tool call.
    """
    registered = list(_HOOK_REGISTRY.get(event, ()))
    if not registered:
        return
    context = TaskHookContext(
        event=event,
        task_list_id=task_list_id,
        task=task,
        agent_id=agent_id,
        metadata=metadata,
    )
    for fn in registered:
        try:
            outcome = fn(context)
            # Hooks may be plain callables or coroutines — await the latter.
            if inspect.isawaitable(outcome):
                await outcome
        except BlockingHookError:
            raise
        except Exception:
            # Non-blocking exceptions are logged but do not abort the operation.
            logger.exception("Non-blocking hook failed for %s", event)
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Data models for the task tracker.
|
||||
|
||||
The schema follows the UI-facing task-record shape with one notable
|
||||
difference: ids are integers (Python is cleaner that way) and rendered
|
||||
as ``#N`` only in user-facing strings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
    """Lifecycle states of a TaskRecord; the values are the wire strings."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
|
||||
|
||||
|
||||
class TaskListRole(StrEnum):
    """Distinguishes a colony template from a session-scoped working list.

    Used for sanity-checking which write paths are allowed (e.g. the four
    session tools must never touch a ``template`` list).
    """

    TEMPLATE = "template"  # colony:{colony_id}
    SESSION = "session"  # session:{agent_id}:{session_id}
|
||||
|
||||
|
||||
class TaskRecord(BaseModel):
    """One unit of work tracked by an agent.

    Persisted as one zero-padded ``NNNN.json`` file per record (see
    store.py); ids are integers and rendered as ``#N`` only in
    user-facing strings.
    """

    id: int  # monotonic, never reused — see store.py
    subject: str
    description: str = ""
    active_form: str | None = None  # present-continuous label, surfaces in UI
    owner: str | None = None  # agent_id of the owning agent
    status: TaskStatus = TaskStatus.PENDING
    blocks: list[int] = Field(default_factory=list)  # ids this task blocks
    blocked_by: list[int] = Field(default_factory=list)  # ids gating this task
    metadata: dict[str, Any] = Field(default_factory=dict)  # free-form; e.g. assigned_session, _internal
    created_at: float = Field(default_factory=time.time)  # unix seconds
    updated_at: float = Field(default_factory=time.time)  # unix seconds
|
||||
|
||||
|
||||
class TaskListMeta(BaseModel):
    """Per-list metadata stored in ``meta.json`` next to the task files."""

    task_list_id: str
    role: TaskListRole  # template vs session — gates which tools may write
    creator_agent_id: str | None = None
    created_at: float = Field(default_factory=time.time)
    last_seen_session_ids: list[str] = Field(default_factory=list)  # audit trail, capped at 10 entries
    schema_version: int = 1
|
||||
|
||||
|
||||
# Tagged union for claim_task_with_busy_check. Used by run_parallel_workers
# when stamping ``assigned_session`` on a colony template entry — the only
# place a "claim" actually happens under the hive model.
@dataclass
class ClaimOk:
    kind: Literal["ok"]
    record: TaskRecord  # the claimed record


@dataclass
class ClaimNotFound:
    kind: Literal["not_found"]


@dataclass
class ClaimAlreadyOwned:
    kind: Literal["already_owned"]
    by: str  # claimant that already holds the task


@dataclass
class ClaimAlreadyCompleted:
    kind: Literal["already_completed"]


@dataclass
class ClaimBlocked:
    kind: Literal["blocked"]
    by: list[int]  # blocking task ids


# Callers discriminate on the ``kind`` tag (or isinstance checks).
ClaimResult = ClaimOk | ClaimNotFound | ClaimAlreadyOwned | ClaimAlreadyCompleted | ClaimBlocked
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Periodic task-reminder injection.
|
||||
|
||||
After enough silent turns since the last task tool call, inject a
|
||||
reminder summarizing the current open tasks. Catches the failure mode
|
||||
where the agent has silently absorbed multiple finished steps into one
|
||||
in_progress task and stopped using the task tools.
|
||||
|
||||
The reminder counter lives on the AgentLoop instance; this module owns
|
||||
the policy (threshold, cooldown, message text) and the integration
|
||||
helper. Wiring lives in :mod:`framework.tasks.integrations.agent_loop`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from framework.tasks.models import TaskRecord, TaskStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Both thresholds are env-tunable (tests, site policy); default 8 turns.
REMINDER_THRESHOLD_TURNS = int(os.environ.get("HIVE_TASK_REMINDER_TURNS", "8"))
REMINDER_COOLDOWN_TURNS = int(os.environ.get("HIVE_TASK_REMINDER_COOLDOWN", "8"))

# Names that count as "task ops" — calling any of these resets the silence
# counter. Keep narrow: only mutating ops re-establish discipline. task_list
# / task_get are read-only and shouldn't reset the counter (the agent could
# read forever without making progress).
TASK_OP_TOOL_NAMES: frozenset[str] = frozenset(
    {
        "task_create",
        # build_reminder explicitly tells the agent to use task_create_batch;
        # complying must reset the counter too, or the reminder keeps firing.
        "task_create_batch",
        "task_update",
        "colony_template_add",
        "colony_template_update",
        "colony_template_remove",
    }
)
|
||||
|
||||
|
||||
@dataclass
class ReminderState:
    """Per-loop silence counters; the agent loop bumps them every iteration."""

    turns_since_task_op: int = 0
    turns_since_last_reminder: int = 0

    def on_iteration(self) -> None:
        """One loop turn elapsed; advance both windows."""
        self.turns_since_task_op = self.turns_since_task_op + 1
        self.turns_since_last_reminder = self.turns_since_last_reminder + 1

    def on_task_op(self) -> None:
        """A mutating task tool ran — the silence window restarts."""
        self.turns_since_task_op = 0

    def on_reminder_sent(self) -> None:
        """A reminder was injected — start the cooldown window."""
        self.turns_since_last_reminder = 0

    def should_remind(self, has_open_tasks: bool) -> bool:
        """True when open tasks exist and both windows have fully elapsed."""
        if not has_open_tasks:
            return False
        if self.turns_since_task_op < REMINDER_THRESHOLD_TURNS:
            return False
        return self.turns_since_last_reminder >= REMINDER_COOLDOWN_TURNS
|
||||
|
||||
|
||||
def saw_task_op(tool_names: Iterable[str]) -> bool:
    """True if any of the names is a counter-resetting task op."""
    for name in tool_names:
        if name in TASK_OP_TOOL_NAMES:
            return True
    return False
|
||||
|
||||
|
||||
def build_reminder(records: list[TaskRecord]) -> str:
    """Compose the reminder body — pending/in-progress focus.

    Returns "" when every record is completed (caller skips injection).
    Output is bounded: at most 5 in_progress ids inline and 10 open
    tasks listed, so the injected message stays small.
    """
    # Anything not completed counts as "open".
    open_ = [r for r in records if r.status != TaskStatus.COMPLETED]
    if not open_:
        return ""
    in_progress = [r for r in open_ if r.status == TaskStatus.IN_PROGRESS]
    head = (
        "[task_reminder] The task tools haven't been used in several "
        "turns. If you're working on tasks that would benefit from "
        "tracked progress:"
    )
    bullets = [
        " - Mark the in_progress task `completed` THE MOMENT it's done — "
        "before starting the next step. Don't batch completions.",
        " - If you've finished work that wasn't on the list, add a "
        "task_create + task_update completed pair so the panel reflects it.",
        " - If you're umbrella-tracking ('reply to all posts' as one task), "
        "break it into one task per atomic action — use `task_create_batch` "
        "with one entry per action.",
    ]
    if in_progress:
        # Surface the currently-active ids so stale in_progress rows get noticed.
        bullets.append(
            " - Currently in_progress (consider whether they're really "
            "still active): " + ", ".join(f'#{r.id} "{r.subject}"' for r in in_progress[:5])
        )
    listing = ["", "Open tasks:"]
    for r in open_[:10]:
        listing.append(f" #{r.id} [{r.status.value}] {r.subject}")
    if len(open_) > 10:
        listing.append(f" ... and {len(open_) - 10} more")
    listing.append("\nOnly act on this if relevant to the current work. NEVER mention this reminder to the user.")
    return "\n".join([head, *bullets, *listing])
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Task list id resolution.
|
||||
|
||||
Under the corrected model (see plan §5):
|
||||
|
||||
- Every agent session owns one task list: ``session:{agent_id}:{session_id}``
|
||||
- The colony has a separate template list: ``colony:{colony_id}``
|
||||
|
||||
``resolve_task_list_id(ctx)`` returns the agent's OWN session list id —
|
||||
what the four task tools write to. The colony template is addressed via
|
||||
the dedicated ``colony_template_*`` tools and the UI; never via the four
|
||||
session tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def session_task_list_id(agent_id: str, session_id: str) -> str:
    """Canonical id of an agent session's own working task list."""
    return "session:" + agent_id + ":" + session_id
|
||||
|
||||
|
||||
def colony_task_list_id(colony_id: str) -> str:
    """Canonical id of a colony's template task list."""
    return "colony:" + colony_id
|
||||
|
||||
|
||||
def parse_task_list_id(task_list_id: str) -> dict[str, str]:
    """Decode a task_list_id into its component parts.

    Returns a dict with at least ``kind`` ("session" / "colony" / "unscoped"
    / "raw"), and the relevant ids when applicable.
    """
    prefix, sep, remainder = task_list_id.partition(":")
    if sep:
        if prefix == "session":
            agent_id, _, session_id = remainder.partition(":")
            return {"kind": "session", "agent_id": agent_id, "session_id": session_id}
        if prefix == "colony":
            return {"kind": "colony", "colony_id": remainder}
        if prefix == "unscoped":
            return {"kind": "unscoped", "agent_id": remainder}
    # Unknown shape — hand back the raw value untouched.
    return {"kind": "raw", "value": task_list_id}
|
||||
|
||||
|
||||
def resolve_task_list_id(ctx: Any) -> str:
    """Return the agent's own session-scoped task list id.

    Resolution priority:

    1. ``HIVE_TASK_LIST_ID`` env var (test/CLI override)
    2. ``ctx.task_list_id`` if already populated by the runner
    3. ``session:{ctx.agent_id}:{ctx.run_id or ctx.execution_id}``
    4. ``unscoped:{ctx.agent_id}`` sentinel (should not happen in prod)
    """
    env_override = os.environ.get("HIVE_TASK_LIST_ID")
    if env_override:
        return env_override

    preset = getattr(ctx, "task_list_id", None)
    if preset:
        return preset

    agent_id = getattr(ctx, "agent_id", None) or ""
    session_id = ""
    # First truthy candidate wins, in priority order.
    for attr in ("run_id", "execution_id", "stream_id"):
        candidate = getattr(ctx, attr, None)
        if candidate:
            session_id = candidate
            break
    if agent_id and session_id:
        return session_task_list_id(agent_id, session_id)

    fallback = f"unscoped:{agent_id or 'unknown'}"
    logger.warning(
        "resolve_task_list_id falling back to %s — agent_id=%r session_id=%r",
        fallback,
        agent_id,
        session_id,
    )
    return fallback
|
||||
@@ -0,0 +1,722 @@
|
||||
"""File-backed task store with filelock-based coordination.
|
||||
|
||||
Layout per list::
|
||||
|
||||
{root}/{task_list_id}/
|
||||
meta.json -- TaskListMeta
|
||||
tasks/
|
||||
0001.json -- TaskRecord (zero-padded for ls-sort)
|
||||
0002.json
|
||||
...
|
||||
.lock -- list-level lock
|
||||
.highwatermark -- ID floor (deleted ids never reused)
|
||||
|
||||
Two list-roots:
|
||||
|
||||
colony:{colony_id} -> ~/.hive/colonies/{colony_id}/tasks/
|
||||
session:{a}:{s} -> ~/.hive/agents/{a}/sessions/{s}/tasks/
|
||||
|
||||
All filesystem I/O is wrapped in ``asyncio.to_thread`` so the event loop
|
||||
never blocks. Locks use a 30-retry / ~2.6s budget — comfortable headroom
|
||||
for the only realistic write contender (colony template under concurrent
|
||||
``colony_template_*`` and ``run_parallel_workers`` stamps).
|
||||
|
||||
The "_unsafe" variants exist because filelock is **not re-entrant**: a
|
||||
caller already holding a lock must NOT re-acquire it (would deadlock).
|
||||
The unsafe path skips acquisition and is callable only from inside another
|
||||
locked function. See ``claim_task_with_busy_check`` and ``delete_task``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from framework.tasks.models import (
|
||||
ClaimAlreadyCompleted,
|
||||
ClaimAlreadyOwned,
|
||||
ClaimBlocked,
|
||||
ClaimNotFound,
|
||||
ClaimOk,
|
||||
ClaimResult,
|
||||
TaskListMeta,
|
||||
TaskListRole,
|
||||
TaskRecord,
|
||||
TaskStatus,
|
||||
)
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LOCK_TIMEOUT_SECONDS = 3.0  # ~30 retries × ~100ms


class _Unset:
    """Sentinel for "owner argument not provided" — distinct from owner=None."""

    __slots__ = ()  # no per-instance dict; the type exists only for isinstance/identity


# Single shared instance used as the default for TaskStore.update_task(owner=...).
_UNSET_SENTINEL: _Unset = _Unset()
|
||||
|
||||
|
||||
def _hive_root() -> Path:
    """Location of the hive data dir; honors HIVE_HOME for tests."""
    raw = os.environ.get("HIVE_HOME")
    if raw is None:
        return Path.home() / ".hive"
    return Path(raw)
|
||||
|
||||
|
||||
def task_list_path(task_list_id: str, *, hive_root: Path | None = None) -> Path:
    """Resolve task_list_id -> on-disk root.

    Raises ValueError for a ``session:`` id that lacks a session component.
    """
    base = hive_root or _hive_root()
    kind, sep, remainder = task_list_id.partition(":")
    if sep and kind == "colony":
        return base / "colonies" / remainder / "tasks"
    if sep and kind == "session":
        agent_id, _, session_id = remainder.partition(":")
        if not session_id:
            raise ValueError(f"Malformed session task_list_id: {task_list_id!r}")
        return base / "agents" / agent_id / "sessions" / session_id / "tasks"
    if sep and kind == "unscoped":
        return base / "unscoped" / remainder / "tasks"
    # Last-ditch sanitization for HIVE_TASK_LIST_ID overrides — slugify the
    # whole thing so the test/dev path can't escape the hive root.
    safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in task_list_id)
    return base / "_misc" / safe
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TaskStore — public façade
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TaskStore:
|
||||
"""Async wrapper around the on-disk store.
|
||||
|
||||
A single TaskStore is fine to share across the process; locking is
|
||||
file-based, so even multiple processes are safe.
|
||||
"""
|
||||
|
||||
    def __init__(self, *, hive_root: Path | None = None) -> None:
        # Optional root override (tests); None means "derive from HIVE_HOME / ~/.hive".
        self._hive_root = hive_root
|
||||
|
||||
# ----- list-level ---------------------------------------------------
|
||||
|
||||
async def ensure_task_list(
|
||||
self,
|
||||
task_list_id: str,
|
||||
*,
|
||||
role: TaskListRole,
|
||||
creator_agent_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> TaskListMeta:
|
||||
"""Create a list if absent; if present, append session_id to last_seen.
|
||||
|
||||
Idempotent: callers (ColonyRuntime bringup, lazy session creation)
|
||||
can call this every time.
|
||||
"""
|
||||
return await asyncio.to_thread(
|
||||
self._ensure_task_list_sync,
|
||||
task_list_id,
|
||||
role,
|
||||
creator_agent_id,
|
||||
session_id,
|
||||
)
|
||||
|
||||
async def list_exists(self, task_list_id: str) -> bool:
|
||||
"""A list exists if its meta.json OR any task file is on disk.
|
||||
|
||||
meta.json is normally written by ``ensure_task_list``, but session
|
||||
lists may be created lazily via the first ``task_create`` (see
|
||||
``_create_task_sync``) — in that case meta.json is backfilled the
|
||||
first time the list is read. Until then, we still want to expose
|
||||
the list's tasks via REST.
|
||||
"""
|
||||
|
||||
def _check() -> bool:
|
||||
root = self._list_root(task_list_id)
|
||||
if (root / "meta.json").exists():
|
||||
return True
|
||||
tasks_dir = root / "tasks"
|
||||
if tasks_dir.exists() and any(p.suffix == ".json" for p in tasks_dir.iterdir()):
|
||||
return True
|
||||
return False
|
||||
|
||||
return await asyncio.to_thread(_check)
|
||||
|
||||
    async def get_meta(self, task_list_id: str) -> TaskListMeta | None:
        """Read the list's TaskListMeta; None when absent or unparseable."""
        return await asyncio.to_thread(self._read_meta_sync, task_list_id)
|
||||
|
||||
    async def reset_task_list(self, task_list_id: str) -> None:
        """Delete all task files but preserve the high-water-mark.

        Test helper. Never wired to runtime lifecycle. Keeping the
        high-water-mark means ids stay monotonic across resets.
        """
        await asyncio.to_thread(self._reset_sync, task_list_id)
|
||||
|
||||
# ----- task CRUD ----------------------------------------------------
|
||||
|
||||
    async def create_tasks_batch(
        self,
        task_list_id: str,
        specs: list[dict[str, Any]],
    ) -> list[TaskRecord]:
        """Atomically create N tasks under a single list-lock acquisition.

        Each spec is a dict with keys: subject (required), description,
        active_form, owner, metadata. Ids are assigned sequentially and
        contiguously — if any task fails to write, an exception is raised
        and the whole batch is rolled back (file unlinked, high-water-mark
        kept at the prior value).

        Atomic-or-none semantics matter for the tool surface: a failed
        partial batch would leave the LLM reasoning about cleanup, which
        defeats the point of batching as a single decision.

        Raises ValueError when any spec has a missing/empty subject.
        """
        return await asyncio.to_thread(
            self._create_tasks_batch_sync, task_list_id, specs
        )
|
||||
|
||||
    async def create_task(
        self,
        task_list_id: str,
        *,
        subject: str,
        description: str = "",
        active_form: str | None = None,
        owner: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TaskRecord:
        """Create one task; id assignment and file write happen under the list lock."""
        return await asyncio.to_thread(
            self._create_task_sync,
            task_list_id,
            subject,
            description,
            active_form,
            owner,
            metadata or {},
        )
|
||||
|
||||
    async def get_task(self, task_list_id: str, task_id: int) -> TaskRecord | None:
        """Read one record; None when the file is missing or unparseable."""
        return await asyncio.to_thread(self._read_task_sync, task_list_id, task_id)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
task_list_id: str,
|
||||
*,
|
||||
include_internal: bool = False,
|
||||
) -> list[TaskRecord]:
|
||||
records = await asyncio.to_thread(self._list_tasks_sync, task_list_id)
|
||||
if include_internal:
|
||||
return records
|
||||
return [r for r in records if not r.metadata.get("_internal")]
|
||||
|
||||
    async def update_task(
        self,
        task_list_id: str,
        task_id: int,
        *,
        subject: str | None = None,
        description: str | None = None,
        active_form: str | None = None,
        owner: str | None | _Unset = _UNSET_SENTINEL,
        status: TaskStatus | None = None,
        add_blocks: list[int] | None = None,
        add_blocked_by: list[int] | None = None,
        metadata_patch: dict[str, Any] | None = None,
    ) -> tuple[TaskRecord | None, list[str]]:
        """Update a task; returns (new_record, fields_changed) or (None, []).

        ``owner`` defaults to ``_UNSET_SENTINEL`` (not ``None``) so that
        ``owner=None`` can be passed explicitly to clear ownership —
        distinct from "not provided".
        """
        return await asyncio.to_thread(
            self._update_task_sync,
            task_list_id,
            task_id,
            subject,
            description,
            active_form,
            owner,
            status,
            add_blocks,
            add_blocked_by,
            metadata_patch,
        )
|
||||
|
||||
    async def delete_task(self, task_list_id: str, task_id: int) -> tuple[bool, list[int]]:
        """Delete a task; returns (was_deleted, cascaded_ids).

        ``cascaded_ids`` are the ids of other tasks whose blocks/blocked_by
        referenced the deleted id and were stripped.
        """
        return await asyncio.to_thread(self._delete_task_sync, task_list_id, task_id)
|
||||
|
||||
    async def claim_task_with_busy_check(
        self,
        task_list_id: str,
        task_id: int,
        claimant: str,
    ) -> ClaimResult:
        """Atomic claim under list-lock.

        Used internally by ``run_parallel_workers`` when stamping
        ``metadata.assigned_session`` on colony template entries — not
        exposed to LLMs as a worker-facing claim race.

        Returns one of the ``Claim*`` variants (see ``models.ClaimResult``).
        """
        return await asyncio.to_thread(self._claim_sync, task_list_id, task_id, claimant)
|
||||
|
||||
# =====================================================================
|
||||
# Sync internals — all called via asyncio.to_thread
|
||||
# =====================================================================
|
||||
|
||||
    def _list_root(self, task_list_id: str) -> Path:
        # On-disk root of the list (see the module docstring for the layout).
        return task_list_path(task_list_id, hive_root=self._hive_root)

    def _tasks_dir(self, task_list_id: str) -> Path:
        # Directory holding the per-task NNNN.json files.
        return self._list_root(task_list_id) / "tasks"

    def _list_lock(self, task_list_id: str) -> FileLock:
        # FileLock targets a sentinel file; it tolerates the file being absent
        # by creating it on first acquire. We use the .lock filename so it's
        # visible alongside the other list files.
        root = self._list_root(task_list_id)
        root.mkdir(parents=True, exist_ok=True)
        return FileLock(str(root / ".lock"), timeout=LOCK_TIMEOUT_SECONDS)

    def _highwatermark_path(self, task_list_id: str) -> Path:
        # Plain-text id floor — deleted ids are never reused (module docstring).
        return self._list_root(task_list_id) / ".highwatermark"

    def _meta_path(self, task_list_id: str) -> Path:
        return self._list_root(task_list_id) / "meta.json"

    def _task_path(self, task_list_id: str, task_id: int) -> Path:
        # Zero-padded to four digits so a plain ls sorts numerically.
        return self._tasks_dir(task_list_id) / f"{task_id:04d}.json"
|
||||
|
||||
# ----- meta ---------------------------------------------------------
|
||||
|
||||
    def _ensure_task_list_sync(
        self,
        task_list_id: str,
        role: TaskListRole,
        creator_agent_id: str | None,
        session_id: str | None,
    ) -> TaskListMeta:
        """Lock-held upsert of meta.json; returns the resulting meta.

        Runs on a worker thread via ``ensure_task_list``. Directory creation
        happens before the lock is taken because _list_lock itself needs the
        root to exist.
        """
        root = self._list_root(task_list_id)
        root.mkdir(parents=True, exist_ok=True)
        (root / "tasks").mkdir(exist_ok=True)
        meta_path = self._meta_path(task_list_id)
        with self._list_lock(task_list_id):
            if meta_path.exists():
                meta = self._read_meta_sync(task_list_id)
                if meta is None:
                    # File existed but failed to parse — rewrite fresh.
                    meta = TaskListMeta(
                        task_list_id=task_list_id,
                        role=role,
                        creator_agent_id=creator_agent_id,
                    )
                if session_id and session_id not in meta.last_seen_session_ids:
                    meta.last_seen_session_ids.append(session_id)
                    # Cap at 10 to keep the audit trail bounded.
                    meta.last_seen_session_ids = meta.last_seen_session_ids[-10:]
                self._write_meta_sync(task_list_id, meta)
                return meta
            # First sight of this list — write fresh meta.
            meta = TaskListMeta(
                task_list_id=task_list_id,
                role=role,
                creator_agent_id=creator_agent_id,
                last_seen_session_ids=[session_id] if session_id else [],
            )
            self._write_meta_sync(task_list_id, meta)
            return meta
|
||||
|
||||
def _read_meta_sync(self, task_list_id: str) -> TaskListMeta | None:
|
||||
path = self._meta_path(task_list_id)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return TaskListMeta.model_validate_json(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
logger.warning("Corrupt meta.json at %s", path, exc_info=True)
|
||||
return None
|
||||
|
||||
def _write_meta_sync(self, task_list_id: str, meta: TaskListMeta) -> None:
|
||||
path = self._meta_path(task_list_id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with atomic_write(path) as f:
|
||||
f.write(meta.model_dump_json(indent=2))
|
||||
|
||||
# ----- task IO ------------------------------------------------------
|
||||
|
||||
def _read_task_sync(self, task_list_id: str, task_id: int) -> TaskRecord | None:
|
||||
path = self._task_path(task_list_id, task_id)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return TaskRecord.model_validate_json(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
logger.warning("Corrupt task file at %s", path, exc_info=True)
|
||||
return None
|
||||
|
||||
def _write_task_sync(self, task_list_id: str, record: TaskRecord) -> None:
|
||||
path = self._task_path(task_list_id, record.id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with atomic_write(path) as f:
|
||||
f.write(record.model_dump_json(indent=2))
|
||||
|
||||
def _list_tasks_sync(self, task_list_id: str) -> list[TaskRecord]:
|
||||
d = self._tasks_dir(task_list_id)
|
||||
if not d.exists():
|
||||
return []
|
||||
records: list[TaskRecord] = []
|
||||
for path in sorted(d.iterdir()):
|
||||
if path.suffix != ".json":
|
||||
continue
|
||||
try:
|
||||
records.append(TaskRecord.model_validate_json(path.read_text(encoding="utf-8")))
|
||||
except Exception:
|
||||
logger.warning("Skipping corrupt task file %s", path, exc_info=True)
|
||||
records.sort(key=lambda r: r.id)
|
||||
return records
|
||||
|
||||
# ----- highwatermark / id assignment --------------------------------
|
||||
|
||||
def _read_highwatermark_sync(self, task_list_id: str) -> int:
|
||||
path = self._highwatermark_path(task_list_id)
|
||||
if not path.exists():
|
||||
return 0
|
||||
try:
|
||||
return int(path.read_text(encoding="utf-8").strip() or "0")
|
||||
except (ValueError, OSError):
|
||||
return 0
|
||||
|
||||
def _write_highwatermark_sync(self, task_list_id: str, value: int) -> None:
|
||||
path = self._highwatermark_path(task_list_id)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with atomic_write(path) as f:
|
||||
f.write(str(value))
|
||||
|
||||
def _next_id_sync(self, task_list_id: str) -> int:
|
||||
"""Compute next id under the assumption the list-lock is held."""
|
||||
existing = self._list_tasks_sync(task_list_id)
|
||||
max_existing = max((r.id for r in existing), default=0)
|
||||
floor = self._read_highwatermark_sync(task_list_id)
|
||||
return max(max_existing, floor) + 1
|
||||
|
||||
# ----- create -------------------------------------------------------
|
||||
|
||||
    def _create_task_sync(
        self,
        task_list_id: str,
        subject: str,
        description: str,
        active_form: str | None,
        owner: str | None,
        metadata: dict[str, Any],
    ) -> TaskRecord:
        """Lock-held create: backfill meta, assign the next id, write, bump HWM."""
        with self._list_lock(task_list_id):
            # Lazy-create meta.json on first task. Session lists are
            # frequently created via the first task_create (no explicit
            # ensure_task_list call); without this backfill the REST
            # endpoint can't discover them. Role is inferred from prefix.
            if not self._meta_path(task_list_id).exists():
                inferred_role = TaskListRole.TEMPLATE if task_list_id.startswith("colony:") else TaskListRole.SESSION
                self._write_meta_sync(
                    task_list_id,
                    TaskListMeta(
                        task_list_id=task_list_id,
                        role=inferred_role,
                    ),
                )
            new_id = self._next_id_sync(task_list_id)
            now = time.time()
            record = TaskRecord(
                id=new_id,
                subject=subject,
                description=description,
                active_form=active_form,
                owner=owner,
                status=TaskStatus.PENDING,
                metadata=metadata,
                created_at=now,
                updated_at=now,
            )
            self._write_task_sync(task_list_id, record)
            # Bump high-water-mark eagerly so even a concurrent racer that
            # somehow missed the listing snapshot can't pick the same id.
            if new_id > self._read_highwatermark_sync(task_list_id):
                self._write_highwatermark_sync(task_list_id, new_id)
            return record
|
||||
|
||||
def _create_tasks_batch_sync(
    self,
    task_list_id: str,
    specs: list[dict[str, Any]],
) -> list[TaskRecord]:
    """Create several PENDING tasks as an all-or-nothing batch.

    Each spec requires a non-empty string ``subject``; ``description``,
    ``active_form``, ``owner`` and ``metadata`` are optional. Ids are
    allocated contiguously starting at the next free id. On any write
    failure every file written so far is unlinked and the exception is
    re-raised; the high-water-mark is bumped only after the whole batch
    is on disk.

    Raises:
        ValueError: if any spec lacks a non-empty string ``subject``.
    """
    if not specs:
        return []
    # Validate up-front so we don't half-create on a malformed entry.
    for i, spec in enumerate(specs):
        subj = spec.get("subject")
        if not isinstance(subj, str) or not subj.strip():
            raise ValueError(f"specs[{i}].subject must be a non-empty string")

    with self._list_lock(task_list_id):
        # Same lazy meta backfill as _create_task_sync.
        if not self._meta_path(task_list_id).exists():
            inferred_role = (
                TaskListRole.TEMPLATE
                if task_list_id.startswith("colony:")
                else TaskListRole.SESSION
            )
            self._write_meta_sync(
                task_list_id,
                TaskListMeta(task_list_id=task_list_id, role=inferred_role),
            )

        base_id = self._next_id_sync(task_list_id)
        # One timestamp for the whole batch.
        now = time.time()
        records: list[TaskRecord] = []
        for offset, spec in enumerate(specs):
            rec = TaskRecord(
                id=base_id + offset,
                subject=spec["subject"],
                description=spec.get("description", ""),
                active_form=spec.get("active_form"),
                owner=spec.get("owner"),
                status=TaskStatus.PENDING,
                # Copy so callers can't mutate the stored metadata later.
                metadata=dict(spec.get("metadata") or {}),
                created_at=now,
                updated_at=now,
            )
            records.append(rec)

        # Write all task files; on any failure, unlink everything we
        # wrote so far and re-raise. High-water-mark is bumped only
        # after a successful full-batch write.
        written: list[Path] = []
        try:
            for rec in records:
                self._write_task_sync(task_list_id, rec)
                written.append(self._task_path(task_list_id, rec.id))
        except Exception:
            for path in written:
                try:
                    path.unlink(missing_ok=True)
                except OSError:
                    # Best-effort rollback: log and keep unlinking the rest.
                    logger.warning("Failed to roll back batch task at %s", path, exc_info=True)
            raise

        highest = records[-1].id
        if highest > self._read_highwatermark_sync(task_list_id):
            self._write_highwatermark_sync(task_list_id, highest)
        return records
|
||||
|
||||
# ----- update -------------------------------------------------------
|
||||
|
||||
def _update_task_sync(
    self,
    task_list_id: str,
    task_id: int,
    subject: str | None,
    description: str | None,
    active_form: str | None,
    owner: str | None | _Unset,
    status: TaskStatus | None,
    add_blocks: list[int] | None,
    add_blocked_by: list[int] | None,
    metadata_patch: dict[str, Any] | None,
) -> tuple[TaskRecord | None, list[str]]:
    """Locked wrapper around :meth:`_update_task_unsafe`.

    Returns ``(None, [])`` when the task does not exist; otherwise
    forwards every field to the unlocked implementation and returns its
    ``(record, changed_fields)`` result.
    """
    with self._list_lock(task_list_id):
        record = self._read_task_sync(task_list_id, task_id)
        if record is None:
            return None, []
        return self._update_task_unsafe(
            task_list_id,
            record,
            subject=subject,
            description=description,
            active_form=active_form,
            owner=owner,
            status=status,
            add_blocks=add_blocks,
            add_blocked_by=add_blocked_by,
            metadata_patch=metadata_patch,
        )
|
||||
|
||||
def _update_task_unsafe(
    self,
    task_list_id: str,
    current: TaskRecord,
    *,
    subject: str | None = None,
    description: str | None = None,
    active_form: str | None = None,
    owner: str | None | _Unset = _UNSET_SENTINEL,
    status: TaskStatus | None = None,
    add_blocks: list[int] | None = None,
    add_blocked_by: list[int] | None = None,
    metadata_patch: dict[str, Any] | None = None,
) -> tuple[TaskRecord, list[str]]:
    """Update without acquiring the list-lock. Caller MUST hold it.

    Returns ``(record, changed_fields)``. When no field actually
    changes, nothing is written and ``changed_fields`` is empty.
    ``owner`` defaults to a sentinel so an explicit ``owner=None`` can
    clear ownership. ``add_blocks`` / ``add_blocked_by`` also write the
    inverse edge onto each referenced target task, keeping the
    dependency graph bidirectionally consistent. ``metadata_patch``
    merges keys; a value of ``None`` deletes the key.
    """
    changed: list[str] = []
    new = current.model_copy(deep=True)

    # Scalar fields: only record a change when the value really differs.
    if subject is not None and subject != new.subject:
        new.subject = subject
        changed.append("subject")
    if description is not None and description != new.description:
        new.description = description
        changed.append("description")
    if active_form is not None and active_form != new.active_form:
        new.active_form = active_form
        changed.append("active_form")
    if not isinstance(owner, _Unset) and owner != new.owner:
        new.owner = owner
        changed.append("owner")
    if status is not None and status != new.status:
        new.status = status
        changed.append("status")
    if add_blocks:
        for b in add_blocks:
            # Ignore duplicates and self-references.
            if b not in new.blocks and b != new.id:
                new.blocks.append(b)
                if "blocks" not in changed:
                    changed.append("blocks")
                # Maintain the bidirectional invariant by stamping
                # blocked_by on the target as well.
                target = self._read_task_sync(task_list_id, b)
                if target and new.id not in target.blocked_by:
                    target.blocked_by.append(new.id)
                    target.updated_at = time.time()
                    self._write_task_sync(task_list_id, target)
    if add_blocked_by:
        for b in add_blocked_by:
            if b not in new.blocked_by and b != new.id:
                new.blocked_by.append(b)
                if "blocked_by" not in changed:
                    changed.append("blocked_by")
                # Inverse edge: mark the blocker as blocking this task.
                target = self._read_task_sync(task_list_id, b)
                if target and new.id not in target.blocks:
                    target.blocks.append(new.id)
                    target.updated_at = time.time()
                    self._write_task_sync(task_list_id, target)
    if metadata_patch is not None:
        md = dict(new.metadata)
        for k, v in metadata_patch.items():
            if v is None:
                # None is the deletion marker, not a storable value.
                md.pop(k, None)
            else:
                md[k] = v
        if md != new.metadata:
            new.metadata = md
            changed.append("metadata")

    # No effective change: skip the write so updated_at stays put.
    if not changed:
        return new, []

    new.updated_at = time.time()
    self._write_task_sync(task_list_id, new)
    return new, changed
|
||||
|
||||
# ----- delete -------------------------------------------------------
|
||||
|
||||
def _delete_task_sync(self, task_list_id: str, task_id: int) -> tuple[bool, list[int]]:
    """Delete a task and strip dangling references from its peers.

    Returns ``(deleted, cascaded_ids)`` where ``cascaded_ids`` lists
    the tasks whose ``blocks``/``blocked_by`` edges were rewritten.
    Returns ``(False, [])`` when the task file does not exist.
    """
    with self._list_lock(task_list_id):
        path = self._task_path(task_list_id, task_id)
        if not path.exists():
            return False, []
        # 1. Bump high-water-mark BEFORE unlinking so a crash mid-delete
        #    can't accidentally re-allocate the id.
        current_floor = self._read_highwatermark_sync(task_list_id)
        if task_id > current_floor:
            self._write_highwatermark_sync(task_list_id, task_id)
        # 2. Unlink the task itself.
        path.unlink()
        # 3. Cascade: strip references from all other tasks.
        cascaded: list[int] = []
        for other in self._list_tasks_sync(task_list_id):
            touched = False
            if task_id in other.blocks:
                other.blocks = [b for b in other.blocks if b != task_id]
                touched = True
            if task_id in other.blocked_by:
                other.blocked_by = [b for b in other.blocked_by if b != task_id]
                touched = True
            if touched:
                other.updated_at = time.time()
                self._write_task_sync(task_list_id, other)
                cascaded.append(other.id)
        return True, cascaded
|
||||
|
||||
# ----- reset --------------------------------------------------------
|
||||
|
||||
def _reset_sync(self, task_list_id: str) -> None:
    """Delete every task in the list while preserving id monotonicity.

    The high-water-mark is raised to the highest id seen *before* any
    file is unlinked, so ids from the wiped generation are never
    reissued to tasks created afterwards.
    """
    with self._list_lock(task_list_id):
        records = self._list_tasks_sync(task_list_id)
        highest = max((record.id for record in records), default=0)
        current_floor = self._read_highwatermark_sync(task_list_id)
        self._write_highwatermark_sync(task_list_id, max(highest, current_floor))
        tasks_dir = self._tasks_dir(task_list_id)
        if tasks_dir.exists():
            for entry in tasks_dir.iterdir():
                if entry.suffix == ".json":
                    entry.unlink()
|
||||
|
||||
# ----- claim --------------------------------------------------------
|
||||
|
||||
def _claim_sync(self, task_list_id: str, task_id: int, claimant: str) -> ClaimResult:
    """Attempt to claim *task_id* for *claimant*.

    Returns the appropriate ClaimResult variant: not-found, already
    completed, owned by someone else, blocked by incomplete
    dependencies, or ok (with the updated record).
    """
    with self._list_lock(task_list_id):
        record = self._read_task_sync(task_list_id, task_id)
        if record is None:
            return ClaimNotFound(kind="not_found")
        if record.status == TaskStatus.COMPLETED:
            return ClaimAlreadyCompleted(kind="already_completed")
        if record.owner is not None and record.owner != claimant:
            return ClaimAlreadyOwned(kind="already_owned", by=record.owner)
        # A blocker only counts when it still exists and isn't done.
        unresolved = [
            b
            for b in record.blocked_by
            if (dep := self._read_task_sync(task_list_id, b)) is not None
            and dep.status != TaskStatus.COMPLETED
        ]
        if unresolved:
            return ClaimBlocked(kind="blocked", by=unresolved)
        claimed, _ = self._update_task_unsafe(task_list_id, record, owner=claimant)
        return ClaimOk(kind="ok", record=claimed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Process-wide singleton (small, stateless wrapper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_default_store: TaskStore | None = None


def get_task_store() -> TaskStore:
    """Process-wide default TaskStore (resolves HIVE_HOME at first call).

    Tests should construct a TaskStore directly with hive_root=tmp_path
    rather than relying on the singleton.
    """
    global _default_store
    store = _default_store
    if store is None:
        store = TaskStore()
        _default_store = store
    return store
|
||||
|
||||
|
||||
# Convenience for tests / utilities.
|
||||
def fingerprint_for_test(task_list_id: str, hive_root: Path) -> Iterable[Path]:
    """Return every path under a list root, sorted — used by tests to
    assert byte-equivalence pre/post shutdown. Empty when the list root
    does not exist.
    """
    root = task_list_path(task_list_id, hive_root=hive_root)
    return sorted(root.rglob("*")) if root.exists() else []
|
||||
@@ -0,0 +1,280 @@
|
||||
"""End-to-end tests:
|
||||
|
||||
- Session task tools fire EventBus events
|
||||
- REST routes return correct snapshots
|
||||
- run_parallel_workers-style flow stamps assigned_session
|
||||
- Durability: store survives a process boundary (subprocess)
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
import os
import subprocess
import sys
from collections.abc import AsyncIterator, Iterator
from pathlib import Path

import pytest
import pytest_asyncio
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer

from framework.host.event_bus import AgentEvent, EventBus, EventType
from framework.llm.provider import ToolUse
from framework.loader.tool_registry import ToolRegistry
from framework.tasks import TaskListRole, TaskStore
from framework.tasks.events import set_default_event_bus
from framework.tasks.hooks import clear_hooks
from framework.tasks.tools import register_colony_template_tools, register_task_tools
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_hooks() -> Iterator[None]:
    """Autouse guard: scrub the process-global task hooks before AND
    after every test so registrations never leak across tests.

    Note: this is a yield fixture, so the return annotation is
    Iterator[None] (the previous ``-> None`` was wrong for a generator).
    """
    clear_hooks()
    yield
    clear_hooks()
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path: Path) -> TaskStore:
    """A TaskStore rooted in the per-test tmp dir (bypasses the singleton)."""
    return TaskStore(hive_root=tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
def registry(store: TaskStore) -> ToolRegistry:
    """Registry with both the session task tools and the colony template
    tools (colony_id='abc') mounted against the per-test store."""
    reg = ToolRegistry()
    register_task_tools(reg, store=store)
    register_colony_template_tools(reg, colony_id="abc", store=store)
    return reg
|
||||
|
||||
|
||||
async def _invoke(registry: ToolRegistry, name: str, **inputs):
    """Execute tool *name* through the registry, awaiting async executors."""
    call = ToolUse(id=f"call_{name}", name=name, input=inputs)
    outcome = registry.get_executor()(call)
    if asyncio.iscoroutine(outcome):
        return await outcome
    return outcome
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventBus integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_task_created_emits_event(registry: ToolRegistry) -> None:
    """task_create publishes a TASK_CREATED event on the default bus."""
    bus = EventBus()
    set_default_event_bus(bus)
    received: list[AgentEvent] = []

    async def handler(ev: AgentEvent) -> None:
        received.append(ev)

    bus.subscribe([EventType.TASK_CREATED], handler)

    try:
        token = ToolRegistry.set_execution_context(agent_id="alice", task_list_id="session:alice:s1")
        try:
            await _invoke(registry, "task_create", subject="hello")
        finally:
            ToolRegistry.reset_execution_context(token)

        # Allow the publish to fan out.
        await asyncio.sleep(0.05)
        assert len(received) == 1
        assert received[0].type == EventType.TASK_CREATED
        assert received[0].data["task"]["subject"] == "hello"
    finally:
        # Always detach the global bus — previously this ran after the
        # asserts, so a failing assertion leaked the bus into later tests.
        set_default_event_bus(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_task_updated_emits_event(registry: ToolRegistry) -> None:
    """task_update publishes a TASK_UPDATED event on the default bus."""
    bus = EventBus()
    set_default_event_bus(bus)
    received: list[AgentEvent] = []

    async def handler(ev: AgentEvent) -> None:
        received.append(ev)

    bus.subscribe([EventType.TASK_UPDATED], handler)

    try:
        token = ToolRegistry.set_execution_context(agent_id="alice", task_list_id="session:alice:s1")
        try:
            await _invoke(registry, "task_create", subject="x")
            await _invoke(registry, "task_update", id=1, status="in_progress")
        finally:
            ToolRegistry.reset_execution_context(token)
        await asyncio.sleep(0.05)
        assert len(received) >= 1
        assert received[0].type == EventType.TASK_UPDATED
    finally:
        # Detach the bus even when an assertion fails so it can't leak
        # into other tests.
        set_default_event_bus(None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# REST routes integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def http_client(tmp_path: Path) -> AsyncIterator[TestClient]:
    """Spin up a stripped-down aiohttp app exposing only the task routes.

    Fixes over the previous version: correct async-generator return
    annotation, and HIVE_HOME / the store singleton are restored on
    teardown instead of leaking into subsequent tests.
    """
    # Point the default TaskStore at the tmp_path so routes see our test data.
    previous_home = os.environ.get("HIVE_HOME")
    os.environ["HIVE_HOME"] = str(tmp_path)
    # Force a fresh singleton.
    import framework.tasks.store as _store_mod

    _store_mod._default_store = None

    from framework.server.routes_tasks import register_routes

    app = web.Application()
    register_routes(app)
    client = TestClient(TestServer(app))
    await client.start_server()
    try:
        yield client
    finally:
        await client.close()
        # Undo the global mutations so later tests aren't silently
        # pinned to this test's tmp_path.
        _store_mod._default_store = None
        if previous_home is None:
            os.environ.pop("HIVE_HOME", None)
        else:
            os.environ["HIVE_HOME"] = previous_home
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rest_get_task_list_404(http_client: TestClient) -> None:
    """Unknown lists return 404 but still echo the requested id."""
    resp = await http_client.get("/api/tasks/session:nope:nope")
    payload = await resp.json()
    assert resp.status == 404
    assert payload["task_list_id"] == "session:nope:nope"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rest_get_task_list_after_create(http_client: TestClient) -> None:
    """GET /api/tasks/<id> reflects tasks created through the store."""
    # Create a list + task via the store directly so we don't have to mount
    # the tools just for this test.
    from framework.tasks import get_task_store

    list_id = "session:alice:s1"
    store = get_task_store()
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    await store.create_task(list_id, subject="abc")

    resp = await http_client.get(f"/api/tasks/{list_id}")
    assert resp.status == 200
    payload = await resp.json()
    assert payload["task_list_id"] == list_id
    assert payload["role"] == "session"
    assert [t["subject"] for t in payload["tasks"]] == ["abc"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rest_colony_lists(http_client: TestClient) -> None:
    """The colony endpoint derives both the template and queen-session ids."""
    url = "/api/colonies/test_colony/task_lists?queen_session_id=sess123"
    resp = await http_client.get(url)
    assert resp.status == 200
    payload = await resp.json()
    assert payload["template_task_list_id"] == "colony:test_colony"
    assert payload["queen_session_task_list_id"] == "session:queen:sess123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cross-process durability — write in subprocess A, read in subprocess B.
|
||||
# Demonstrates the "task survives runtime restart" guarantee.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_durability_across_subprocesses(tmp_path: Path) -> None:
    """A task written by one interpreter is readable from a fresh one —
    the 'survives runtime restart' durability guarantee."""
    env = dict(os.environ)
    env["HIVE_HOME"] = str(tmp_path)
    env["PYTHONUNBUFFERED"] = "1"

    def run_snippet(code: str) -> str:
        """Run *code* in a brand-new interpreter sharing HIVE_HOME."""
        proc = subprocess.run(
            [sys.executable, "-c", code],
            env=env,
            check=True,
            capture_output=True,
            text=True,
        )
        return proc.stdout.strip()

    write_script = """
import asyncio
from framework.tasks import TaskStore, TaskListRole

async def main():
    s = TaskStore()
    await s.ensure_task_list('session:a:b', role=TaskListRole.SESSION)
    rec = await s.create_task('session:a:b', subject='persisted')
    print(rec.id)

asyncio.run(main())
"""
    assert int(run_snippet(write_script)) == 1

    read_script = """
import asyncio
from framework.tasks import TaskStore

async def main():
    s = TaskStore()
    rs = await s.list_tasks('session:a:b')
    print(len(rs), rs[0].subject if rs else '')

asyncio.run(main())
"""
    count, subject = run_snippet(read_script).split(" ", 1)
    assert count == "1"
    assert subject == "persisted"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# "run_parallel_workers" style flow at the storage level.
|
||||
# Validates plan-and-spawn pattern: queen publishes templates, then stamps
|
||||
# assigned_session per spawned worker.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_template_assignment_flow(store: TaskStore) -> None:
    """run_parallel_workers-style flow: publish template entries, then
    stamp assigned_session / assigned_worker_id per spawned worker."""
    template_id = "colony:swarm"
    await store.ensure_task_list(template_id, role=TaskListRole.TEMPLATE)

    created = [
        await store.create_task(template_id, subject=subject)
        for subject in ("crawl A", "crawl B")
    ]

    # Simulate run_parallel_workers stamping after spawn.
    for idx, rec in enumerate(created, start=1):
        await store.update_task(
            template_id,
            rec.id,
            metadata_patch={
                "assigned_session": f"session:w{idx}:w{idx}",
                "assigned_worker_id": f"w{idx}",
            },
        )

    rs = await store.list_tasks(template_id)
    assert all(r.metadata.get("assigned_worker_id") for r in rs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reset preserves byte-equivalence semantics (durability under graceful op)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graceful_no_op_preserves_files(store: TaskStore, tmp_path: Path) -> None:
    """The store has no shutdown hook — touching it never deletes files.

    Snapshot is keyed by full path (the previous filename keying could
    let identically named files in different directories mask each other).
    """
    list_id = "session:a:b"
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    rec = await store.create_task(list_id, subject="x")
    pre_bytes = {p: p.read_bytes() for p in tmp_path.rglob("*.json")}

    # Simulate "agent loop teardown" — should be a no-op.
    # (No method to call — the absence of teardown hooks IS the test.)
    post = sorted(tmp_path.rglob("*.json"))
    assert set(post) == set(pre_bytes)
    for p in post:
        assert p.read_bytes() == pre_bytes[p]
    assert rec.id == 1
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Integration tests that wire multiple subsystems together.
|
||||
|
||||
Verifies the plan-and-spawn pattern end-to-end:
|
||||
- Queen authors colony template entries (via colony_template_add)
|
||||
- "spawn" stamps assigned_session metadata + emits the right event
|
||||
- Workers operate on their own session list (no fall-through)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.host.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.llm.provider import ToolUse
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
from framework.tasks import TaskListRole, TaskStore
|
||||
from framework.tasks.events import (
|
||||
emit_colony_template_assignment,
|
||||
set_default_event_bus,
|
||||
)
|
||||
from framework.tasks.hooks import clear_hooks
|
||||
from framework.tasks.scoping import (
|
||||
colony_task_list_id,
|
||||
session_task_list_id,
|
||||
)
|
||||
from framework.tasks.tools import register_colony_template_tools, register_task_tools
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_hooks() -> None:  # NOTE(review): yields — strictly Iterator[None]; confirm and fix with an import
    """Autouse guard: clear process-global task hooks before and after
    each test so hook registrations never leak across tests."""
    clear_hooks()
    yield
    clear_hooks()
|
||||
|
||||
|
||||
async def _invoke(reg: ToolRegistry, name: str, **inputs):
    """Execute tool *name* through the registry, awaiting async executors."""
    call = ToolUse(id=f"call_{name}", name=name, input=inputs)
    outcome = reg.get_executor()(call)
    if asyncio.iscoroutine(outcome):
        return await outcome
    return outcome
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_queen_plans_workers_pick_up(tmp_path: Path) -> None:
    """Queen authors a 3-step plan; we simulate spawning 3 workers, each
    associated with one template entry. Each worker writes to its own
    session list. The colony template gets stamped with assigned_session.

    Fix: the global event bus is now detached in a ``finally`` block —
    previously a failing assertion leaked it into subsequent tests.
    """
    bus = EventBus()
    set_default_event_bus(bus)
    received: list[AgentEvent] = []

    async def handler(ev: AgentEvent) -> None:
        received.append(ev)

    bus.subscribe(
        [
            EventType.TASK_CREATED,
            EventType.TASK_UPDATED,
            EventType.COLONY_TEMPLATE_ASSIGNMENT,
        ],
        handler,
    )

    try:
        store = TaskStore(hive_root=tmp_path)
        queen_reg = ToolRegistry()
        register_task_tools(queen_reg, store=store)
        register_colony_template_tools(queen_reg, colony_id="alpha", store=store)

        # 1. Queen authors the plan.
        qtoken = ToolRegistry.set_execution_context(
            agent_id="queen",
            task_list_id=session_task_list_id("queen", "qsess"),
            colony_id="alpha",
        )
        try:
            for subject in ("crawl A", "crawl B", "crawl C"):
                r = await _invoke(queen_reg, "colony_template_add", subject=subject)
                assert json.loads(r.content)["success"] is True

            # Verify the colony template now has 3 entries.
            list_result = await _invoke(queen_reg, "colony_template_list")
            body = json.loads(list_result.content)
            assert body["count"] == 3
            template_entries = body["tasks"]
        finally:
            ToolRegistry.reset_execution_context(qtoken)

        template_list_id = colony_task_list_id("alpha")

        # 2. Simulate spawning a worker per template entry: stamp the
        #    assigned_session and emit the assignment event.
        worker_ids = ["w1", "w2", "w3"]
        for entry, wid in zip(template_entries, worker_ids, strict=True):
            await store.update_task(
                template_list_id,
                entry["id"],
                metadata_patch={
                    "assigned_session": session_task_list_id(wid, wid),
                    "assigned_worker_id": wid,
                },
            )
            await emit_colony_template_assignment(
                colony_id="alpha",
                task_id=entry["id"],
                assigned_session=session_task_list_id(wid, wid),
                assigned_worker_id=wid,
            )

        # 3. Each worker operates on its OWN session list.
        for wid in worker_ids:
            worker_reg = ToolRegistry()
            register_task_tools(worker_reg, store=store)
            wtoken = ToolRegistry.set_execution_context(
                agent_id=wid, task_list_id=session_task_list_id(wid, wid)
            )
            try:
                await _invoke(worker_reg, "task_create", subject=f"setup for {wid}")
                await _invoke(worker_reg, "task_update", id=1, status="in_progress")
            finally:
                ToolRegistry.reset_execution_context(wtoken)

        # 4. Verify the colony template entries are stamped + workers have
        #    their own private lists.
        template_after = await store.list_tasks(template_list_id)
        assert all(
            t.metadata.get("assigned_worker_id") in {"w1", "w2", "w3"}
            for t in template_after
        )

        for wid in worker_ids:
            worker_tasks = await store.list_tasks(session_task_list_id(wid, wid))
            assert len(worker_tasks) == 1
            assert worker_tasks[0].owner == wid  # auto-stamped on in_progress
            assert worker_tasks[0].subject == f"setup for {wid}"

        # 5. Confirm the assignment events fired.
        await asyncio.sleep(0.05)
        assignments = [e for e in received if e.type == EventType.COLONY_TEMPLATE_ASSIGNMENT]
        assert len(assignments) == 3
    finally:
        set_default_event_bus(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_tools_never_touch_template(tmp_path: Path) -> None:
    """The four session tools must operate exclusively on the session list.

    Even when colony_id is set in execution context, task_create writes to
    session list, not the template.
    """
    store = TaskStore(hive_root=tmp_path)
    registry = ToolRegistry()
    register_task_tools(registry, store=store)

    session_list = session_task_list_id("alice", "sess1")
    token = ToolRegistry.set_execution_context(
        agent_id="alice",
        task_list_id=session_list,
        colony_id="alpha",  # has colony_id but we still write to session
    )
    try:
        await _invoke(registry, "task_create", subject="my work")
    finally:
        ToolRegistry.reset_execution_context(token)

    # Session list got the task.
    assert len(await store.list_tasks(session_list)) == 1

    # Colony template MUST be empty (no leakage).
    assert not await store.list_exists(colony_task_list_id("alpha"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resume_persisted_handle(tmp_path: Path) -> None:
    """A session list created in 'session A' is still readable as long as
    we resolve to the same task_list_id."""
    original = TaskStore(hive_root=tmp_path)
    list_id = session_task_list_id("alice", "sess_persistent")

    await original.ensure_task_list(list_id, role=TaskListRole.SESSION)
    for subject in ("a", "b"):
        await original.create_task(list_id, subject=subject)

    # Simulate a fresh process / "resume" — same hive_root, same list_id.
    resumed = TaskStore(hive_root=tmp_path)
    assert [t.subject for t in await resumed.list_tasks(list_id)] == ["a", "b"]
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Tests for the periodic task-reminder logic.
|
||||
|
||||
The reminder state is a small counter machine; the policy is:
|
||||
- Bump on each iteration
|
||||
- Reset to zero on any task op tool call (task_create / task_update /
|
||||
colony_template_*)
|
||||
- When ``turns_since_task_op >= REMINDER_THRESHOLD_TURNS`` AND
|
||||
``turns_since_last_reminder >= REMINDER_COOLDOWN_TURNS`` AND there
|
||||
are open tasks, fire a reminder
|
||||
|
||||
The build_reminder helper composes the message body — checked for the
|
||||
key behavioral nudges (granularity + completion discipline).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.tasks import TaskListRole, TaskStore
|
||||
from framework.tasks.models import TaskStatus
|
||||
from framework.tasks.reminders import (
|
||||
REMINDER_COOLDOWN_TURNS,
|
||||
REMINDER_THRESHOLD_TURNS,
|
||||
ReminderState,
|
||||
build_reminder,
|
||||
saw_task_op,
|
||||
)
|
||||
|
||||
|
||||
def test_state_bumps_each_iteration() -> None:
    """Both counters advance by one per loop iteration."""
    state = ReminderState()
    for _ in range(2):
        state.on_iteration()
    assert state.turns_since_task_op == 2
    assert state.turns_since_last_reminder == 2
|
||||
|
||||
|
||||
def test_state_resets_on_task_op() -> None:
    """A task op zeroes the op counter but not the reminder cooldown."""
    state = ReminderState()
    for _ in range(5):
        state.on_iteration()
    state.on_task_op()
    assert state.turns_since_task_op == 0
    # Reminder cooldown is independent — it tracks reminders, not ops.
    assert state.turns_since_last_reminder == 5
|
||||
|
||||
|
||||
def test_should_remind_below_threshold() -> None:
    """One turn short of the op threshold must not fire."""
    state = ReminderState()
    state.turns_since_last_reminder = REMINDER_COOLDOWN_TURNS
    state.turns_since_task_op = REMINDER_THRESHOLD_TURNS - 1
    assert not state.should_remind(has_open_tasks=True)
|
||||
|
||||
|
||||
def test_should_remind_no_tasks() -> None:
    """Even far past both windows, no open tasks means no reminder."""
    state = ReminderState()
    state.turns_since_last_reminder = REMINDER_COOLDOWN_TURNS + 5
    state.turns_since_task_op = REMINDER_THRESHOLD_TURNS + 5
    assert not state.should_remind(has_open_tasks=False)
|
||||
|
||||
|
||||
def test_should_remind_at_threshold() -> None:
    """Counters exactly at their limits are sufficient to fire."""
    state = ReminderState()
    state.turns_since_last_reminder = REMINDER_COOLDOWN_TURNS
    state.turns_since_task_op = REMINDER_THRESHOLD_TURNS
    assert state.should_remind(has_open_tasks=True)
|
||||
|
||||
|
||||
def test_cooldown_blocks_back_to_back() -> None:
    """Sending a reminder restarts the cooldown, blocking an instant repeat."""
    state = ReminderState()
    state.turns_since_task_op = REMINDER_THRESHOLD_TURNS + 5
    state.on_reminder_sent()
    assert not state.should_remind(has_open_tasks=True)
|
||||
|
||||
|
||||
def test_saw_task_op_recognizes_mutating_tools() -> None:
    """Only mutating task tools count as a 'task op'."""
    for tools in (["task_create"], ["read_file", "task_update"], ["colony_template_add"]):
        assert saw_task_op(tools)
    # Reads do NOT reset the counter — important: model could read forever
    # without making progress.
    for tools in (["task_list", "task_get"], []):
        assert not saw_task_op(tools)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_reminder_includes_open_tasks(tmp_path: Path) -> None:
    """The reminder body lists every open task and keeps the behavioral nudges."""
    store = TaskStore(hive_root=tmp_path)
    list_id = "session:a:b"
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    created = [await store.create_task(list_id, subject=f"step {n}") for n in (1, 2, 3)]
    # Mark #2 in_progress so the reminder mentions it.
    await store.update_task(list_id, created[1].id, status=TaskStatus.IN_PROGRESS)

    body = build_reminder(await store.list_tasks(list_id))

    assert "task_reminder" in body
    for n in (1, 2, 3):
        assert f"step {n}" in body
    lowered = body.lower()
    # Granularity nudge present.
    assert "umbrella" in lowered or "atomic" in lowered
    # Completion-discipline nudge present.
    assert "completed" in lowered
    # Anti-nag boilerplate remains present.
    assert "NEVER mention this reminder to the user" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_reminder_empty_when_no_open(tmp_path: Path) -> None:
    """With every task completed the reminder collapses to an empty string."""
    store = TaskStore(hive_root=tmp_path)
    list_id = "session:a:b"
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    rec = await store.create_task(list_id, subject="done already")
    await store.update_task(list_id, rec.id, status=TaskStatus.COMPLETED)

    assert build_reminder(await store.list_tasks(list_id)) == ""
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Tests for resolve_task_list_id."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.tasks.scoping import (
|
||||
colony_task_list_id,
|
||||
parse_task_list_id,
|
||||
resolve_task_list_id,
|
||||
session_task_list_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class FakeCtx:
    """Minimal stand-in for the execution context: only the identity
    fields resolve_task_list_id() inspects, all defaulting to empty."""

    agent_id: str = ""
    run_id: str = ""
    execution_id: str = ""
    stream_id: str = ""
    task_list_id: str | None = None
|
||||
|
||||
|
||||
def test_session_helper() -> None:
    """session_task_list_id joins agent and session under the session: prefix."""
    assert session_task_list_id("a", "b") == "session:a:b"


def test_colony_helper() -> None:
    """colony_task_list_id prefixes the colony id with colony:."""
    assert colony_task_list_id("c") == "colony:c"


def test_parse_session() -> None:
    """A session-scoped id parses back into its three components."""
    expected = {"kind": "session", "agent_id": "agent", "session_id": "sess"}
    assert parse_task_list_id("session:agent:sess") == expected


def test_parse_colony() -> None:
    """A colony-scoped id parses into kind + colony_id."""
    expected = {"kind": "colony", "colony_id": "abc"}
    assert parse_task_list_id("colony:abc") == expected
|
||||
|
||||
|
||||
def test_resolve_uses_existing(monkeypatch: pytest.MonkeyPatch) -> None:
    """An explicit task_list_id on the context wins when no env override is set."""
    monkeypatch.delenv("HIVE_TASK_LIST_ID", raising=False)
    context = FakeCtx(agent_id="x", run_id="r1", task_list_id="session:x:r1")
    assert resolve_task_list_id(context) == "session:x:r1"


def test_resolve_env_override(monkeypatch: pytest.MonkeyPatch) -> None:
    """HIVE_TASK_LIST_ID overrides whatever the context carries."""
    monkeypatch.setenv("HIVE_TASK_LIST_ID", "forced")
    context = FakeCtx(agent_id="x", run_id="r1")
    assert resolve_task_list_id(context) == "forced"


def test_resolve_synthesizes_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no override and no explicit id, a session id is synthesized
    from agent_id + run_id."""
    monkeypatch.delenv("HIVE_TASK_LIST_ID", raising=False)
    context = FakeCtx(agent_id="alice", run_id="r123")
    assert resolve_task_list_id(context) == "session:alice:r123"


def test_resolve_falls_back_to_unscoped(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without even a run_id, resolution degrades to an unscoped: id."""
    monkeypatch.delenv("HIVE_TASK_LIST_ID", raising=False)
    context = FakeCtx(agent_id="alice")
    assert resolve_task_list_id(context).startswith("unscoped:")
|
||||
@@ -0,0 +1,273 @@
|
||||
"""Tests for the file-backed task store.
|
||||
|
||||
Concurrency / id-monotonicity / cascade / claim / reset — the engineering
|
||||
primitives the rest of the system relies on.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.tasks import TaskListRole, TaskStatus, TaskStore
|
||||
from framework.tasks.models import ClaimAlreadyOwned, ClaimBlocked, ClaimNotFound, ClaimOk
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path: Path) -> TaskStore:
    """Fresh file-backed store rooted in this test's tmp dir."""
    return TaskStore(hive_root=tmp_path)


@pytest.fixture
def list_id() -> str:
    """Canonical session-scoped list id shared by the store tests."""
    return "session:test_agent:test_session"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_and_get(store: TaskStore, list_id: str) -> None:
    """A freshly created task gets id 1 and reads back as PENDING."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    created = await store.create_task(list_id, subject="hi")
    assert created.id == 1

    loaded = await store.get_task(list_id, 1)
    assert loaded is not None
    assert loaded.subject == "hi"
    assert loaded.status == TaskStatus.PENDING


@pytest.mark.asyncio
async def test_get_missing_returns_none(store: TaskStore, list_id: str) -> None:
    """Looking up an id that was never created yields None, not an error."""
    assert await store.get_task(list_id, 999) is None


@pytest.mark.asyncio
async def test_list_ascending(store: TaskStore, list_id: str) -> None:
    """list_tasks returns records in ascending-id order."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    for subject in ("a", "b", "c"):
        await store.create_task(list_id, subject=subject)

    listed = await store.list_tasks(list_id)
    assert [record.id for record in listed] == [1, 2, 3]


@pytest.mark.asyncio
async def test_list_filters_internal(store: TaskStore, list_id: str) -> None:
    """Tasks flagged _internal are hidden unless include_internal is passed."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    await store.create_task(list_id, subject="visible")
    await store.create_task(list_id, subject="hidden", metadata={"_internal": True})

    visible_only = await store.list_tasks(list_id)
    assert len(visible_only) == 1

    everything = await store.list_tasks(list_id, include_internal=True)
    assert len(everything) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrent creation: two parallel calls -> N and N+1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_create_distinct_ids(store: TaskStore, list_id: str) -> None:
    """Twenty parallel creates must yield exactly ids 1..20 — no dupes, no gaps."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    created = await asyncio.gather(
        *(store.create_task(list_id, subject=f"t{n}") for n in range(20))
    )
    assert sorted(record.id for record in created) == list(range(1, 21))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Update + change detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_returns_changed_fields(store: TaskStore, list_id: str) -> None:
    """update_task reports only fields whose values actually changed."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    created = await store.create_task(list_id, subject="orig")

    # Re-sending the same subject must not register as a change.
    updated, changed = await store.update_task(
        list_id, created.id, subject="orig", status=TaskStatus.IN_PROGRESS
    )
    assert changed == ["status"]
    assert updated.status == TaskStatus.IN_PROGRESS


@pytest.mark.asyncio
async def test_update_missing_returns_none(store: TaskStore, list_id: str) -> None:
    """Updating a nonexistent task yields (None, []) rather than raising."""
    updated, changed = await store.update_task(list_id, 42, subject="x")
    assert updated is None
    assert changed == []


@pytest.mark.asyncio
async def test_metadata_patch_merges_and_deletes(store: TaskStore, list_id: str) -> None:
    """metadata_patch merges new values and treats None as a key delete."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    created = await store.create_task(list_id, subject="x", metadata={"a": 1, "b": 2})
    updated, _ = await store.update_task(list_id, created.id, metadata_patch={"a": 10, "b": None})
    assert updated.metadata == {"a": 10}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bidirectional blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_blocks_bidirectional(store: TaskStore, list_id: str) -> None:
    """Adding a forward `blocks` edge materializes the reverse `blocked_by`."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    first = await store.create_task(list_id, subject="a")
    second = await store.create_task(list_id, subject="b")

    updated_first, _ = await store.update_task(list_id, first.id, add_blocks=[second.id])
    assert second.id in updated_first.blocks

    reloaded_second = await store.get_task(list_id, second.id)
    assert first.id in reloaded_second.blocked_by


@pytest.mark.asyncio
async def test_blocked_by_bidirectional(store: TaskStore, list_id: str) -> None:
    """Adding a `blocked_by` edge materializes the forward `blocks`."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    first = await store.create_task(list_id, subject="a")
    second = await store.create_task(list_id, subject="b")

    updated_second, _ = await store.update_task(list_id, second.id, add_blocked_by=[first.id])
    assert first.id in updated_second.blocked_by

    reloaded_first = await store.get_task(list_id, first.id)
    assert second.id in reloaded_first.blocks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete: highwatermark + cascade
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_increments_highwatermark(store: TaskStore, list_id: str) -> None:
    """Deleting the highest task must not free its id for reuse."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    await store.create_task(list_id, subject="a")
    second = await store.create_task(list_id, subject="b")

    deleted, _ = await store.delete_task(list_id, second.id)
    assert deleted

    replacement = await store.create_task(list_id, subject="c")
    assert replacement.id == second.id + 1, "deleted ids must never be reused"


@pytest.mark.asyncio
async def test_delete_cascades_blocks(store: TaskStore, list_id: str) -> None:
    """Deleting a task scrubs its edges from every neighbor and reports them."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    first = await store.create_task(list_id, subject="a")
    middle = await store.create_task(list_id, subject="b")
    last = await store.create_task(list_id, subject="c")
    await store.update_task(list_id, first.id, add_blocks=[middle.id])
    await store.update_task(list_id, last.id, add_blocked_by=[middle.id])

    _, touched = await store.delete_task(list_id, middle.id)
    assert sorted(touched) == sorted([first.id, last.id])

    reloaded_first = await store.get_task(list_id, first.id)
    reloaded_last = await store.get_task(list_id, last.id)
    assert middle.id not in reloaded_first.blocks
    assert middle.id not in reloaded_last.blocked_by


@pytest.mark.asyncio
async def test_delete_missing_returns_false(store: TaskStore, list_id: str) -> None:
    """Deleting an unknown id reports failure with an empty cascade."""
    deleted, touched = await store.delete_task(list_id, 42)
    assert not deleted
    assert touched == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reset preserves high-water-mark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reset_preserves_floor(store: TaskStore, list_id: str) -> None:
    """Resetting a list clears its tasks but keeps the id high-water mark."""
    await store.ensure_task_list(list_id, role=TaskListRole.SESSION)
    for _ in range(5):
        await store.create_task(list_id, subject="x")
    await store.reset_task_list(list_id)

    created = await store.create_task(list_id, subject="post-reset")
    assert created.id == 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Claim semantics (used by run_parallel_workers)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_ok(store: TaskStore, list_id: str) -> None:
    """An unowned, unblocked task is claimable and records its new owner."""
    await store.ensure_task_list(list_id, role=TaskListRole.TEMPLATE)
    created = await store.create_task(list_id, subject="x")
    outcome = await store.claim_task_with_busy_check(list_id, created.id, "agent_a")
    assert isinstance(outcome, ClaimOk)
    assert outcome.record.owner == "agent_a"


@pytest.mark.asyncio
async def test_claim_already_owned(store: TaskStore, list_id: str) -> None:
    """Claiming a task someone else owns reports the current owner."""
    await store.ensure_task_list(list_id, role=TaskListRole.TEMPLATE)
    created = await store.create_task(list_id, subject="x", owner="agent_a")
    outcome = await store.claim_task_with_busy_check(list_id, created.id, "agent_b")
    assert isinstance(outcome, ClaimAlreadyOwned)
    assert outcome.by == "agent_a"


@pytest.mark.asyncio
async def test_claim_not_found(store: TaskStore, list_id: str) -> None:
    """Claiming a nonexistent task yields ClaimNotFound."""
    outcome = await store.claim_task_with_busy_check(list_id, 999, "agent_a")
    assert isinstance(outcome, ClaimNotFound)


@pytest.mark.asyncio
async def test_claim_blocked(store: TaskStore, list_id: str) -> None:
    """A task with an incomplete prerequisite cannot be claimed."""
    await store.ensure_task_list(list_id, role=TaskListRole.TEMPLATE)
    prereq = await store.create_task(list_id, subject="prereq")
    dependent = await store.create_task(list_id, subject="dep")
    await store.update_task(list_id, dependent.id, add_blocked_by=[prereq.id])

    # prereq is still pending, so the dependent stays blocked.
    outcome = await store.claim_task_with_busy_check(list_id, dependent.id, "agent_a")
    assert isinstance(outcome, ClaimBlocked)
    assert prereq.id in outcome.by
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Meta lifecycle: ensure_task_list is idempotent and tracks last_seen
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ensure_task_list_idempotent(store: TaskStore, list_id: str) -> None:
    """Repeated ensure_task_list reuses the same directory and accumulates
    the session ids it has seen."""
    meta_first = await store.ensure_task_list(list_id, role=TaskListRole.SESSION, session_id="s1")
    meta_second = await store.ensure_task_list(list_id, role=TaskListRole.SESSION, session_id="s2")
    assert meta_first.created_at == meta_second.created_at  # same dir
    assert "s1" in meta_second.last_seen_session_ids
    assert "s2" in meta_second.last_seen_session_ids


@pytest.mark.asyncio
async def test_ensure_task_list_caps_history(store: TaskStore, list_id: str) -> None:
    """last_seen_session_ids keeps only the 10 most recent ids."""
    for index in range(15):
        await store.ensure_task_list(list_id, role=TaskListRole.SESSION, session_id=f"s{index}")
    meta = await store.get_meta(list_id)
    assert len(meta.last_seen_session_ids) == 10
    assert "s14" in meta.last_seen_session_ids
    assert "s4" not in meta.last_seen_session_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path resolution sanity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_colony_path(store: TaskStore, tmp_path: Path) -> None:
    """colony:* lists live under colonies/<id>/tasks."""
    await store.ensure_task_list("colony:abc", role=TaskListRole.TEMPLATE)
    assert (tmp_path / "colonies" / "abc" / "tasks" / "meta.json").exists()


@pytest.mark.asyncio
async def test_session_path(store: TaskStore, tmp_path: Path) -> None:
    """session:* lists live under agents/<agent>/sessions/<session>/tasks."""
    await store.ensure_task_list("session:agent_x:sess_y", role=TaskListRole.SESSION)
    meta_path = tmp_path / "agents" / "agent_x" / "sessions" / "sess_y" / "tasks" / "meta.json"
    assert meta_path.exists()
|
||||
@@ -0,0 +1,488 @@
|
||||
"""End-to-end tool tests via ToolRegistry.get_executor()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
import json
from collections.abc import Iterator
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.llm.provider import ToolUse
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
from framework.tasks import TaskStore
|
||||
from framework.tasks.hooks import (
|
||||
HOOK_TASK_COMPLETED,
|
||||
HOOK_TASK_CREATED,
|
||||
BlockingHookError,
|
||||
clear_hooks,
|
||||
register_hook,
|
||||
)
|
||||
from framework.tasks.tools import register_colony_template_tools, register_task_tools
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_hooks() -> Iterator[None]:
    """Isolate every test from hook registrations made by its siblings.

    Note: this is a yield (generator) fixture, so the correct return
    annotation is ``Iterator[None]`` — the previous ``-> None`` was wrong.
    """
    clear_hooks()
    yield
    clear_hooks()


@pytest.fixture
def store(tmp_path: Path) -> TaskStore:
    """Fresh file-backed store rooted in this test's tmp dir."""
    return TaskStore(hive_root=tmp_path)


@pytest.fixture
def registry_with_session_tools(store: TaskStore) -> ToolRegistry:
    """Registry exposing only the session-scoped task tools."""
    registry = ToolRegistry()
    register_task_tools(registry, store=store)
    return registry


async def _invoke(registry: ToolRegistry, name: str, **inputs):
    """Invoke a tool via the registry's executor protocol.

    The executor may return either a plain result or a coroutine,
    so await only when necessary.
    """
    executor = registry.get_executor()
    result = executor(ToolUse(id=f"call_{name}", name=name, input=inputs))
    if asyncio.iscoroutine(result):
        result = await result
    return result


def _set_ctx(*, agent_id: str, task_list_id: str, **extra):
    """Bind the tool-execution context; returns the token needed to reset it."""
    return ToolRegistry.set_execution_context(agent_id=agent_id, task_list_id=task_list_id, **extra)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session tools — happy paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_then_list(registry_with_session_tools: ToolRegistry) -> None:
    """task_create assigns id 1 and task_list shows the new record."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        created = await _invoke(registry, "task_create", subject="Plan retrieval")
        assert created.is_error is False
        payload = json.loads(created.content)
        assert payload["success"] is True
        assert payload["task_id"] == 1

        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 1
        assert listing["tasks"][0]["subject"] == "Plan retrieval"
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_update_in_progress_auto_owner(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """Moving a task to in_progress auto-fills the caller as owner."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="x")
        updated = await _invoke(registry, "task_update", id=1, status="in_progress")
        payload = json.loads(updated.content)
        assert payload["success"] is True
        assert payload["task"]["status"] == "in_progress"
        assert payload["task"]["owner"] == "agent_a"  # auto-filled
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_status_deleted(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """Updating to status=deleted removes the task from subsequent listings."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="x")
        deleted = await _invoke(registry, "task_update", id=1, status="deleted")
        payload = json.loads(deleted.content)
        assert payload["success"] is True
        assert payload["deleted"] is True

        # Subsequent list sees nothing.
        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 0
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_get_returns_full_record(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """task_get returns the full record, description included."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="x", description="full body")
        fetched = await _invoke(registry, "task_get", id=1)
        payload = json.loads(fetched.content)
        assert payload["task"]["description"] == "full body"
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task-not-found is non-error (so sibling tool cancellation doesn't cascade)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_task_not_found_is_not_error(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """A missing task reports success=False WITHOUT is_error, so the
    streaming executor doesn't cascade-cancel sibling tool calls."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        result = await _invoke(registry, "task_update", id=42, subject="ghost")
        assert result.is_error is False
        payload = json.loads(result.content)
        assert payload["success"] is False
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hooks: task_created blocking deletes the just-created task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_creates_n_tasks_atomically(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """task_create_batch lands every spec in one call with sequential ids."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        result = await _invoke(
            registry,
            "task_create_batch",
            tasks=[
                {"subject": "step 1", "active_form": "Doing 1"},
                {"subject": "step 2"},
                {"subject": "step 3"},
            ],
        )
        assert result.is_error is False
        payload = json.loads(result.content)
        assert payload["success"] is True
        assert payload["task_ids"] == [1, 2, 3]
        # Compact summary message — references first id and the range.
        assert "#1-#3" in payload["message"] or "#1, #2, #3" in payload["message"]
        assert "Mark #1 in_progress" in payload["message"]

        # Sanity: list shows all three.
        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 3
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_rejects_empty(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """An empty spec list is rejected outright."""
    registry = registry_with_session_tools
    ctx_token = _set_ctx(agent_id="a", task_list_id="session:a:s")
    try:
        result = await _invoke(registry, "task_create_batch", tasks=[])
        payload = json.loads(result.content)
        assert payload["success"] is False
        assert "non-empty" in payload["error"]
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_create_batch_rejects_malformed_spec_atomically(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """A bad subject in the middle of the batch must reject the whole
    batch — not leave partial state on disk."""
    registry = registry_with_session_tools
    ctx_token = _set_ctx(agent_id="a", task_list_id="session:a:s")
    try:
        result = await _invoke(
            registry,
            "task_create_batch",
            tasks=[{"subject": "good"}, {"subject": ""}],
        )
        payload = json.loads(result.content)
        assert payload["success"] is False
        # Confirm zero tasks landed.
        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 0
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_hook_blocks_rolls_back_whole_batch(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """If a task_created hook blocks even one task in the batch, the
    entire batch must roll back."""
    registry = registry_with_session_tools

    # Veto exactly one subject in the middle of the batch.
    def selective_blocker(ctx) -> None:
        if ctx.task.subject == "block me":
            raise BlockingHookError("policy")

    register_hook(HOOK_TASK_CREATED, selective_blocker)

    ctx_token = _set_ctx(agent_id="a", task_list_id="session:a:s")
    try:
        result = await _invoke(
            registry,
            "task_create_batch",
            tasks=[
                {"subject": "ok 1"},
                {"subject": "block me"},
                {"subject": "ok 3"},
            ],
        )
        payload = json.loads(result.content)
        assert payload["success"] is False
        assert "rolled back" in payload["error"]
        # All three rolled back.
        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 0
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_create_batch_then_single_create_keeps_id_monotonic(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """task_create_batch uses sequential ids; a follow-up task_create
    should pick up at the next id after the batch's highest."""
    registry = registry_with_session_tools
    ctx_token = _set_ctx(agent_id="a", task_list_id="session:a:s")
    try:
        await _invoke(
            registry,
            "task_create_batch",
            tasks=[{"subject": "a"}, {"subject": "b"}, {"subject": "c"}],
        )
        result = await _invoke(registry, "task_create", subject="d")
        payload = json.loads(result.content)
        assert payload["task_id"] == 4
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_completion_suffix_points_to_next_pending(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """When a task is marked completed, the result should point at the
    lowest-id pending task as a steering nudge."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        for subject in ("step 1", "step 2", "step 3"):
            await _invoke(registry, "task_create", subject=subject)
        await _invoke(registry, "task_update", id=1, status="in_progress")
        result = await _invoke(registry, "task_update", id=1, status="completed")
        payload = json.loads(result.content)
        assert payload["success"] is True
        assert "Next pending: #2" in payload["message"]
        assert "step 2" in payload["message"]
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_completion_suffix_signals_all_done(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """Completing the only task yields the all-done message."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="only step")
        await _invoke(registry, "task_update", id=1, status="in_progress")
        result = await _invoke(registry, "task_update", id=1, status="completed")
        payload = json.loads(result.content)
        assert "All tasks complete" in payload["message"]
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_completion_suffix_skips_blocked_pending(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """If the only pending task is blocked, the suffix should not point at
    it — fall through to "all done" or note in-progress siblings."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="prereq")
        await _invoke(registry, "task_create", subject="blocked dep")
        # #2 is blocked by #1.
        await _invoke(registry, "task_update", id=2, add_blocked_by=[1])
        await _invoke(registry, "task_update", id=1, status="in_progress")
        # Don't actually complete #1 — instead add an unrelated done.
        await _invoke(registry, "task_create", subject="extra step")
        await _invoke(registry, "task_update", id=3, status="in_progress")
        result = await _invoke(registry, "task_update", id=3, status="completed")
        payload = json.loads(result.content)
        # #2 is still blocked by uncompleted #1, so the suffix shouldn't
        # surface it. #1 is in_progress, so the suffix highlights that.
        assert "Still in progress: #1" in payload["message"]
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_hook_blocks_task_created(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """A blocking task_created hook aborts the create and rolls it back."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"

    def blocker(ctx) -> None:
        raise BlockingHookError("test policy")

    register_hook(HOOK_TASK_CREATED, blocker)
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        result = await _invoke(registry, "task_create", subject="will be aborted")
        payload = json.loads(result.content)
        assert payload["success"] is False
        # The task must have been rolled back.
        listing = json.loads((await _invoke(registry, "task_list")).content)
        assert listing["count"] == 0
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_hook_blocks_task_completed(
    registry_with_session_tools: ToolRegistry,
) -> None:
    """A blocking task_completed hook leaves the task in_progress."""
    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"

    def veto(ctx) -> None:
        raise BlockingHookError("nope")

    register_hook(HOOK_TASK_COMPLETED, veto)
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="x")
        await _invoke(registry, "task_update", id=1, status="in_progress")
        result = await _invoke(registry, "task_update", id=1, status="completed")
        payload = json.loads(result.content)
        assert payload["success"] is False
        # Status rolled back to in_progress, not stuck on completed.
        fetched = json.loads((await _invoke(registry, "task_get", id=1)).content)
        assert fetched["task"]["status"] == "in_progress"
    finally:
        ToolRegistry.reset_execution_context(ctx_token)


@pytest.mark.asyncio
async def test_hook_blocks_task_completed_never_writes(
    registry_with_session_tools: ToolRegistry,
    store: TaskStore,
) -> None:
    """Veto-before-write: when the task_completed hook blocks, the COMPLETED
    status must NEVER touch disk — `updated_at` should equal the value from
    the prior in_progress write, not be bumped by a transient COMPLETED
    write + rollback."""
    from framework.tasks.models import TaskStatus

    registry = registry_with_session_tools
    list_id = "session:agent_a:sess_1"

    def veto(ctx) -> None:
        raise BlockingHookError("nope")

    register_hook(HOOK_TASK_COMPLETED, veto)
    ctx_token = _set_ctx(agent_id="agent_a", task_list_id=list_id)
    try:
        await _invoke(registry, "task_create", subject="x")
        await _invoke(registry, "task_update", id=1, status="in_progress")
        # Snapshot updated_at after the in_progress write — this is the
        # value that should persist if veto-before-write is honored.
        snapshot = await store.get_task(list_id, 1)
        assert snapshot is not None
        stamp = snapshot.updated_at

        # Vetoed completion attempt.
        result = await _invoke(registry, "task_update", id=1, status="completed")
        payload = json.loads(result.content)
        assert payload["success"] is False

        # On-disk record must be byte-identical to the pre-veto snapshot —
        # no transient COMPLETED write, no rollback updated_at bump.
        record = await store.get_task(list_id, 1)
        assert record is not None
        assert record.status == TaskStatus.IN_PROGRESS
        assert record.updated_at == stamp, (
            "veto-before-write violated: updated_at changed, indicating a transient write happened"
        )
    finally:
        ToolRegistry.reset_execution_context(ctx_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Colony template tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def queen_registry(store: TaskStore) -> ToolRegistry:
    """Registry wired like the queen's: session task tools PLUS the
    queen-only colony_template_* tools for colony "abc"."""
    reg = ToolRegistry()
    register_task_tools(reg, store=store)
    register_colony_template_tools(reg, colony_id="abc", store=store)
    return reg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_colony_template_add_and_list(queen_registry: ToolRegistry) -> None:
    """colony_template_add writes to the colony template list only —
    the queen's own session task list must stay untouched."""
    reg = queen_registry
    queen_session_list = "session:queen:sess_1"
    token = _set_ctx(agent_id="queen", task_list_id=queen_session_list, colony_id="abc")
    try:
        await _invoke(reg, "colony_template_add", subject="crawl")
        await _invoke(reg, "colony_template_add", subject="parse")
        body = json.loads((await _invoke(reg, "colony_template_list")).content)
        assert body["count"] == 2

        # The session task list should be empty — colony tools don't write there.
        body_session = json.loads((await _invoke(reg, "task_list")).content)
        assert body_session["count"] == 0
    finally:
        ToolRegistry.reset_execution_context(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_colony_template_remove(queen_registry: ToolRegistry) -> None:
    """Removing a template entry reserves its id: the next add must get a
    fresh id, never reuse the removed one."""
    reg = queen_registry
    token = _set_ctx(agent_id="queen", task_list_id="session:queen:sess_1", colony_id="abc")
    try:
        await _invoke(reg, "colony_template_add", subject="a")
        await _invoke(reg, "colony_template_add", subject="b")
        result = await _invoke(reg, "colony_template_remove", id=2)
        body = json.loads(result.content)
        assert body["success"] is True
        # Next add gets id 3 (high-water-mark preserved — ids never reused).
        result2 = await _invoke(reg, "colony_template_add", subject="c")
        body2 = json.loads(result2.content)
        assert body2["task_id"] == 3
    finally:
        ToolRegistry.reset_execution_context(token)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Task tools — the four session-list tools and the queen-only colony template tools."""
|
||||
|
||||
from framework.tasks.tools.register import (
|
||||
register_colony_template_tools,
|
||||
register_task_tools,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"register_colony_template_tools",
|
||||
"register_task_tools",
|
||||
]
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Context resolution for task-tool executors.
|
||||
|
||||
Tool executors run synchronously inside ``ToolRegistry.get_executor()``;
|
||||
they need the calling agent's id and task_list_id to know which list to
|
||||
write to. We pull both from contextvars set by the runner /
|
||||
ColonyRuntime / orchestrator before each agent's iteration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from framework.loader.tool_registry import _execution_context
|
||||
|
||||
|
||||
def current_context() -> dict[str, Any]:
    """Return a shallow copy of the current execution context (empty if unset)."""
    ctx = _execution_context.get()
    return dict(ctx) if ctx else {}
|
||||
|
||||
|
||||
def current_agent_id() -> str | None:
    """Id of the agent currently executing a tool, if the runner set one."""
    ctx = current_context()
    return ctx.get("agent_id")
|
||||
|
||||
|
||||
def current_task_list_id() -> str | None:
    """Session task-list id of the calling agent, if the runner set one."""
    ctx = current_context()
    return ctx.get("task_list_id")
|
||||
|
||||
|
||||
def current_colony_id() -> str | None:
    """Colony id for the current execution, if the runner set one."""
    ctx = current_context()
    return ctx.get("colony_id")
|
||||
|
||||
|
||||
def current_picked_up_from() -> tuple[str, int] | None:
    """If this session was spawned for a colony template entry, return it.

    Returns:
        ``(list_id, entry_id)`` when the execution context carries a valid
        ``picked_up_from`` marker, else ``None``.

    The marker is written as a 2-tuple, but context that round-trips
    through JSON comes back as a list — accept both shapes. A malformed
    entry id yields ``None`` rather than raising out of the accessor.
    """
    raw = current_context().get("picked_up_from")
    if not raw:
        return None
    # Accept tuple (native) or list (JSON round-trip) of exactly two items.
    if isinstance(raw, (tuple, list)) and len(raw) == 2:
        try:
            return raw[0], int(raw[1])
        except (TypeError, ValueError):
            # Entry id not coercible to int — treat as "no marker".
            return None
    return None
|
||||
@@ -0,0 +1,238 @@
|
||||
"""Queen-only colony template tools.
|
||||
|
||||
These tools manipulate a colony's task template — the queen's spawn plan.
|
||||
They are gated to the queen of a colony at registration time
|
||||
(``register_colony_template_tools(colony_id=...)``).
|
||||
|
||||
Workers never see these tools. The four session tools (`task_create`,
|
||||
`task_update`, `task_list`, `task_get`) operate exclusively on the
|
||||
caller's session list — never the colony template.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import Tool
|
||||
from framework.tasks.events import (
|
||||
emit_task_created,
|
||||
emit_task_deleted,
|
||||
emit_task_updated,
|
||||
)
|
||||
from framework.tasks.models import TaskRecord, TaskStatus
|
||||
from framework.tasks.scoping import colony_task_list_id
|
||||
from framework.tasks.store import _UNSET_SENTINEL, TaskStore, get_task_store
|
||||
from framework.tasks.tools.session_tools import _serialize_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _add_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"subject": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"active_form": {"type": "string"},
|
||||
"metadata": {"type": "object"},
|
||||
},
|
||||
"required": ["subject"],
|
||||
}
|
||||
|
||||
|
||||
def _update_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"},
|
||||
"subject": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"active_form": {"type": "string"},
|
||||
"owner": {"type": ["string", "null"]},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed"],
|
||||
},
|
||||
"metadata_patch": {"type": "object"},
|
||||
},
|
||||
"required": ["id"],
|
||||
}
|
||||
|
||||
|
||||
def _remove_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
|
||||
|
||||
def _list_schema() -> dict[str, Any]:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
# Model-facing tool descriptions. These strings are part of the prompt the
# LLM sees — treat their wording as behavior, not documentation.
_ADD_DESC = (
    "Append a task to your colony's spawn-plan template. Templates are read "
    "by `run_parallel_workers` and the UI; workers do not pull from the "
    "template after spawn. Use this to plan colony work before spawning."
)

_UPDATE_DESC = (
    "Update a template entry on your colony's spawn-plan template (e.g., "
    "stamp completion when a worker reports back, adjust subject/description). "
    "Only the queen can call this."
)

_REMOVE_DESC = (
    "Remove a template entry from your colony's spawn-plan template. The "
    "id is reserved (high-water-mark preserved) — never reused."
)

_LIST_DESC = (
    "List all entries on your colony's spawn-plan template. Each entry "
    "includes any `metadata.assigned_session` stamp that ties the entry to "
    "a spawned worker."
)
|
||||
|
||||
|
||||
def _make_add_executor(store: TaskStore, list_id: str):
    """Build the colony_template_add executor bound to *store* and *list_id*.

    Unlike the session task_create executor, there is no task_created hook
    gate here — the template write is unconditional, then an SSE event fires.
    """

    async def execute(inputs: dict) -> dict[str, Any]:
        rec: TaskRecord = await store.create_task(
            list_id,
            subject=inputs["subject"],
            description=inputs.get("description", ""),
            active_form=inputs.get("active_form"),
            metadata=inputs.get("metadata") or {},
        )
        await emit_task_created(task_list_id=list_id, record=rec)
        return {
            "success": True,
            "task_list_id": list_id,
            "task_id": rec.id,
            "message": f"Template entry #{rec.id} added: {rec.subject}",
            "task": _serialize_task(rec),
        }

    return execute
|
||||
|
||||
|
||||
def _make_update_executor(store: TaskStore, list_id: str):
    """Build the colony_template_update executor bound to *store* and *list_id*."""

    async def execute(inputs: dict) -> dict[str, Any]:
        task_id = int(inputs["id"])
        status_in = inputs.get("status")
        # Schema constrains status to pending/in_progress/completed; a value
        # outside the enum would raise ValueError from TaskStatus here.
        status_enum = TaskStatus(status_in) if status_in else None
        # Sentinel default distinguishes "owner not provided" from an
        # explicit owner=null (which clears ownership).
        owner_in = inputs.get("owner", _UNSET_SENTINEL)
        new, fields = await store.update_task(
            list_id,
            task_id,
            subject=inputs.get("subject"),
            description=inputs.get("description"),
            active_form=inputs.get("active_form"),
            owner=owner_in,
            status=status_enum,
            metadata_patch=inputs.get("metadata_patch"),
        )
        if new is None:
            # Missing entry is a soft failure (success=False), not an exception.
            return {
                "success": False,
                "task_list_id": list_id,
                "task_id": task_id,
                "message": f"Template entry #{task_id} not found.",
            }
        if fields:
            # Only emit an SSE update when something actually changed.
            await emit_task_updated(task_list_id=list_id, record=new, fields=fields)
        return {
            "success": True,
            "task_list_id": list_id,
            "task_id": task_id,
            "fields": fields,
            "message": f"Template entry #{task_id} updated. Fields: {', '.join(fields) or '(none)'}.",
            "task": _serialize_task(new),
        }

    return execute
|
||||
|
||||
|
||||
def _make_remove_executor(store: TaskStore, list_id: str):
    """Build the colony_template_remove executor bound to *store* and *list_id*."""

    async def execute(inputs: dict) -> dict[str, Any]:
        task_id = int(inputs["id"])
        # NOTE(review): `cascade` comes straight from TaskStore.delete_task —
        # presumably related records/edges removed alongside; confirm there.
        deleted, cascade = await store.delete_task(list_id, task_id)
        if not deleted:
            return {
                "success": False,
                "task_list_id": list_id,
                "task_id": task_id,
                "message": f"Template entry #{task_id} not found.",
            }
        await emit_task_deleted(task_list_id=list_id, task_id=task_id, cascade=cascade)
        return {
            "success": True,
            "task_list_id": list_id,
            "task_id": task_id,
            "deleted": True,
            "cascade": cascade,
            "message": f"Template entry #{task_id} removed.",
        }

    return execute
|
||||
|
||||
|
||||
def _make_list_executor(store: TaskStore, list_id: str):
    """Build the colony_template_list executor bound to *store* and *list_id*."""

    async def execute(inputs: dict) -> dict[str, Any]:
        entries = await store.list_tasks(list_id)
        serialized = [_serialize_task(entry) for entry in entries]
        return {
            "success": True,
            "task_list_id": list_id,
            "count": len(serialized),
            "tasks": serialized,
        }

    return execute
|
||||
|
||||
|
||||
def build_colony_template_tools(
    *,
    colony_id: str,
    store: TaskStore | None = None,
) -> list[tuple[Tool, Any]]:
    """Build (Tool, executor) pairs for the queen-only colony_template_* tools.

    Each executor is bound at build time to the colony's template list id.
    Writes (add/update/remove) serialize; the list tool is a pure read and
    is marked concurrency-safe.
    """
    backing = store or get_task_store()
    list_id = colony_task_list_id(colony_id)
    # (name, description, schema, concurrency_safe, executor factory)
    specs = [
        ("colony_template_add", _ADD_DESC, _add_schema(), False, _make_add_executor),
        ("colony_template_update", _UPDATE_DESC, _update_schema(), False, _make_update_executor),
        ("colony_template_remove", _REMOVE_DESC, _remove_schema(), False, _make_remove_executor),
        ("colony_template_list", _LIST_DESC, _list_schema(), True, _make_list_executor),
    ]
    return [
        (
            Tool(
                name=name,
                description=desc,
                parameters=params,
                concurrency_safe=safe,
            ),
            factory(backing, list_id),
        )
        for name, desc, params, safe, factory in specs
    ]
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Wire task tools into a ToolRegistry.
|
||||
|
||||
The four session task tools are registered for every agent that gets a
|
||||
ToolRegistry. The colony template tools are queen-only and registered
|
||||
separately by ``register_colony_template_tools``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
from framework.tasks.store import TaskStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _wrap_async_executor(async_executor):
|
||||
"""Adapt an async executor to ToolRegistry's sync executor protocol.
|
||||
|
||||
ToolRegistry's executor expects ``Callable[[dict], Any]`` where Any may
|
||||
be a coroutine; the registry awaits it. We just pass the coroutine
|
||||
through.
|
||||
"""
|
||||
|
||||
def executor(inputs: dict) -> Any:
|
||||
return async_executor(inputs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def register_task_tools(
    registry: ToolRegistry,
    *,
    store: TaskStore | None = None,
) -> None:
    """Register the four session task tools on ``registry``.

    Idempotent: re-registering overwrites the previous executor (which is
    fine — they share the same TaskStore singleton anyway).

    Args:
        registry: The agent's tool registry to install the tools on.
        store: Optional TaskStore override; defaults to the shared singleton.
    """
    # Function-scope import — presumably to avoid a circular import at
    # module load time; confirm before hoisting to the top of the file.
    from framework.tasks.tools.session_tools import build_session_tools

    for tool, async_executor in build_session_tools(store=store):
        registry.register(tool.name, tool, _wrap_async_executor(async_executor))
    # Concurrency-safety travels on the Tool object itself
    # (``tool.concurrency_safe``), so no extra bookkeeping is needed here.
    # (Removed a dead ``if ...: pass`` branch that re-checked
    # ToolRegistry.CONCURRENCY_SAFE_TOOLS and then did nothing — it only
    # risked an AttributeError if that class attribute ever disappears.)
    logger.debug("Registered task tools on %s", registry)
|
||||
|
||||
|
||||
def register_colony_template_tools(
    registry: ToolRegistry,
    *,
    colony_id: str,
    store: TaskStore | None = None,
) -> None:
    """Register the queen-only colony_template_* tools on ``registry``.

    Call this only for a colony's queen — workers and queen-DM never
    receive these tools.
    """
    from framework.tasks.tools.colony_tools import build_colony_template_tools

    for tool, async_executor in build_colony_template_tools(colony_id=colony_id, store=store):
        registry.register(tool.name, tool, _wrap_async_executor(async_executor))
    logger.debug("Registered colony_template_* tools (colony_id=%s)", colony_id)
|
||||
@@ -0,0 +1,594 @@
|
||||
"""The four session task tools: task_create, task_update, task_list, task_get.
|
||||
|
||||
All four operate on the calling agent's OWN session list. They never touch
|
||||
the colony template — the queen has separate ``colony_template_*`` tools
|
||||
for that (see ``colony_tools.py``).
|
||||
|
||||
Concurrency safety:
|
||||
task_list, task_get -> concurrency_safe=True (pure reads)
|
||||
task_create, task_update -> concurrency_safe=False (writes serialize)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from framework.llm.provider import Tool
|
||||
from framework.tasks.events import (
|
||||
emit_task_created,
|
||||
emit_task_deleted,
|
||||
emit_task_updated,
|
||||
)
|
||||
from framework.tasks.hooks import (
|
||||
HOOK_TASK_COMPLETED,
|
||||
HOOK_TASK_CREATED,
|
||||
BlockingHookError,
|
||||
run_task_hooks,
|
||||
)
|
||||
from framework.tasks.models import TaskRecord, TaskStatus
|
||||
from framework.tasks.store import (
|
||||
_UNSET_SENTINEL as _UNSET, # re-export for clarity
|
||||
TaskStore,
|
||||
get_task_store,
|
||||
)
|
||||
from framework.tasks.tools._context import (
|
||||
current_agent_id,
|
||||
current_task_list_id,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas (Anthropic-style JSONSchema)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Status values accepted by task_update. 'deleted' is synthetic: the update
# executor routes it to store.delete_task instead of writing a status.
_TASK_STATUS_VALUES = ["pending", "in_progress", "completed", "deleted"]
|
||||
|
||||
|
||||
def _create_schema() -> dict[str, Any]:
    """JSONSchema for task_create. Only `subject` is required; the
    per-property `description` strings are model-facing prompt text."""
    return {
        "type": "object",
        "properties": {
            "subject": {
                "type": "string",
                "description": "Imperative title (e.g., 'Crawl target URLs').",
            },
            "description": {
                "type": "string",
                "description": "Brief description of what to do.",
            },
            "active_form": {
                "type": "string",
                "description": "Present-continuous label shown while in_progress (e.g., 'Crawling target URLs').",
            },
            "metadata": {
                "type": "object",
                "description": "Arbitrary key/value metadata. Use _internal=true to hide from task_list.",
            },
        },
        "required": ["subject"],
    }
|
||||
|
||||
|
||||
def _update_schema() -> dict[str, Any]:
    """JSONSchema for task_update. Only `id` is required; every other field
    is an optional edit. Status accepts the synthetic 'deleted' value."""
    return {
        "type": "object",
        "properties": {
            "id": {"type": "integer", "description": "Task id (the #N from task_list)."},
            "subject": {"type": "string"},
            "description": {"type": "string"},
            "active_form": {"type": "string"},
            "owner": {
                "type": ["string", "null"],
                "description": "Agent id of the owner. Null clears ownership.",
            },
            "status": {"type": "string", "enum": _TASK_STATUS_VALUES},
            "add_blocks": {
                "type": "array",
                "items": {"type": "integer"},
                "description": "Add task ids that this task blocks (bidirectional).",
            },
            "add_blocked_by": {
                "type": "array",
                "items": {"type": "integer"},
                "description": "Add task ids that block this task (bidirectional).",
            },
            "metadata_patch": {
                "type": "object",
                "description": "Merge into metadata. Null values delete keys.",
            },
        },
        "required": ["id"],
    }
|
||||
|
||||
|
||||
def _list_schema() -> dict[str, Any]:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def _get_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
|
||||
|
||||
def _create_batch_schema() -> dict[str, Any]:
    """JSONSchema for task_create_batch: a non-empty `tasks` array where
    each item mirrors the single task_create payload (subject required)."""
    return {
        "type": "object",
        "properties": {
            "tasks": {
                "type": "array",
                "minItems": 1,
                "description": (
                    "Array of task specs. Each becomes one task with a "
                    "sequential id. Atomic — all created or none."
                ),
                "items": {
                    "type": "object",
                    "properties": {
                        "subject": {
                            "type": "string",
                            "description": "Imperative title (e.g. 'Crawl target URL').",
                        },
                        "description": {"type": "string"},
                        "active_form": {
                            "type": "string",
                            "description": (
                                "Present-continuous label shown while in_progress."
                            ),
                        },
                        "metadata": {"type": "object"},
                    },
                    "required": ["subject"],
                },
            }
        },
        "required": ["tasks"],
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool descriptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Model-facing tool descriptions. These strings steer LLM tool choice and
# are effectively behavior — edit wording deliberately, not cosmetically.
_CREATE_DESC = (
    "Create ONE task on your own session task list. Use this for one-off "
    "mid-run additions when you discover unplanned work after the initial "
    "plan is laid out.\n\n"
    "**For laying out a multi-step plan upfront, use `task_create_batch` "
    "instead** — one tool call with all the steps is cheaper and atomic.\n\n"
    "Fields:\n"
    "- subject: short imperative title (e.g. 'Crawl target URL').\n"
    "- description: optional, slightly longer 'what to do' note.\n"
    "- active_form: present-continuous label shown while in_progress (e.g. "
    "'Crawling target URL'). If omitted, the spinner shows the subject.\n"
    "- metadata: optional KV. Set _internal=true to hide from task_list."
)

_UPDATE_DESC = (
    "Update ONE task on your own session task list. There is no batch "
    "update tool by design — every `completed` transition is a discrete "
    "progress signal to the user.\n\n"
    "Workflow:\n"
    "- Mark a task `in_progress` BEFORE you start working on it.\n"
    "- Mark it `completed` AS SOON as you finish it — do not let "
    "multiple finished tasks pile up unmarked before flushing them at "
    "the end of the run.\n"
    "- Set status='deleted' to drop a task that's no longer relevant.\n\n"
    "ONLY mark `completed` when the task is FULLY done. If you hit errors, "
    "blockers, or partial state, keep it `in_progress` and create a new "
    "task describing what's blocking. Never mark completed with caveats; "
    "if it's not done, it's not done.\n\n"
    "Setting status='in_progress' without owner auto-fills your agent_id."
)

_LIST_DESC = (
    "Show your session task list, sorted by id ascending. Internal tasks "
    "(metadata._internal=true) and resolved blockers are filtered out. "
    "**Prefer working on tasks in id order** (lowest first) — earlier "
    "tasks usually set up context for later ones."
)

_GET_DESC = (
    "Read the full record of one task (description, metadata, timestamps) "
    "from your own session task list. Use this to refresh your view of a "
    "task before updating it if you're not sure of current fields."
)

_CREATE_BATCH_DESC = (
    "Create N tasks at once on your own session task list. **Use this "
    "FIRST when laying out a multi-step plan upfront** — replying to 5 "
    "posts is one `task_create_batch` with 5 entries, not 5 separate "
    "`task_create` calls. Atomic: all-or-none. Use single `task_create` "
    "for one-off mid-run additions when you discover unplanned work, "
    "not for the initial plan."
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_list_id() -> str | None:
    """Pull the calling agent's session task_list_id from execution context.

    Thin alias over :func:`current_task_list_id`, kept so executor bodies
    read naturally; ``None`` means the runner never set a context.
    """
    return current_task_list_id()
|
||||
|
||||
|
||||
def _serialize_task(t: TaskRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"id": t.id,
|
||||
"subject": t.subject,
|
||||
"description": t.description,
|
||||
"active_form": t.active_form,
|
||||
"owner": t.owner,
|
||||
"status": t.status.value,
|
||||
"blocks": list(t.blocks),
|
||||
"blocked_by": list(t.blocked_by),
|
||||
"metadata": dict(t.metadata),
|
||||
"created_at": t.created_at,
|
||||
"updated_at": t.updated_at,
|
||||
}
|
||||
|
||||
|
||||
def _make_create_executor(store: TaskStore):
    """Build the task_create executor. The target list id and agent id are
    resolved from execution context at call time, not bound at build time.

    Hook semantics differ from completion: creation is write-then-rollback
    (the record exists briefly and is deleted if a hook blocks), whereas
    completion is veto-before-write (see _make_update_executor).
    """

    async def execute(inputs: dict) -> dict[str, Any]:
        list_id = _resolve_list_id()
        if not list_id:
            return {"success": False, "error": "No task_list_id resolved for this agent."}
        agent_id = current_agent_id() or ""
        kwargs = {
            "subject": inputs["subject"],
            "description": inputs.get("description", ""),
            "active_form": inputs.get("active_form"),
            "metadata": inputs.get("metadata") or {},
        }
        rec = await store.create_task(list_id, **kwargs)
        # task_created hooks may block creation -> rollback by deleting.
        try:
            await run_task_hooks(
                HOOK_TASK_CREATED,
                task_list_id=list_id,
                task=rec,
                agent_id=agent_id,
            )
        except BlockingHookError as exc:
            logger.warning("task_created hook blocked task #%s: %s", rec.id, exc)
            await store.delete_task(list_id, rec.id)
            return {"success": False, "error": f"Hook blocked task creation: {exc}"}
        # SSE event fires only after the hooks pass.
        await emit_task_created(task_list_id=list_id, record=rec)
        return {
            "success": True,
            "task_list_id": list_id,
            "task_id": rec.id,
            "message": f"Task #{rec.id} created successfully: {rec.subject}",
            "task": _serialize_task(rec),
        }

    return execute
|
||||
|
||||
|
||||
def _make_create_batch_executor(store: TaskStore):
    """Build the task_create_batch executor: atomic N-task creation.

    Atomicity is tool-level: if any task_created hook blocks, every record
    written by the batch is deleted again and a single soft error returns.
    """

    async def execute(inputs: dict) -> dict[str, Any]:
        list_id = _resolve_list_id()
        if not list_id:
            return {"success": False, "error": "No task_list_id resolved for this agent."}
        agent_id = current_agent_id() or ""
        specs = inputs.get("tasks") or []
        if not isinstance(specs, list) or not specs:
            return {
                "success": False,
                "error": "task_create_batch requires a non-empty `tasks` array.",
            }
        # Storage layer validates subject; surface its error as a soft
        # tool_result so sibling tools don't cancel.
        try:
            recs = await store.create_tasks_batch(list_id, specs)
        except ValueError as exc:
            return {"success": False, "error": str(exc)}

        # Run task_created hooks per task; blocking on any aborts the
        # whole batch (delete every record we just wrote, return error).
        for rec in recs:
            try:
                await run_task_hooks(
                    HOOK_TASK_CREATED,
                    task_list_id=list_id,
                    task=rec,
                    agent_id=agent_id,
                )
            except BlockingHookError as exc:
                logger.warning(
                    "task_created hook blocked batch on task #%s: %s",
                    rec.id,
                    exc,
                )
                for r in recs:
                    await store.delete_task(list_id, r.id)
                return {
                    "success": False,
                    "error": (
                        f"Hook blocked task #{rec.id} ({rec.subject!r}); "
                        f"entire batch rolled back: {exc}"
                    ),
                }

        # All hooks passed — only now announce the records over SSE.
        for rec in recs:
            await emit_task_created(task_list_id=list_id, record=rec)

        ids = [r.id for r in recs]
        # Compact summary message — don't flood the conversation with
        # one line per created task. Contiguous ids collapse to "#a-#b".
        if len(ids) == 1:
            range_label = f"#{ids[0]}"
        elif ids == list(range(ids[0], ids[-1] + 1)):
            range_label = f"#{ids[0]}-#{ids[-1]}"
        else:
            range_label = ", ".join(f"#{i}" for i in ids)
        return {
            "success": True,
            "task_list_id": list_id,
            "task_ids": ids,
            "message": (
                f"Created {len(ids)} task(s): {range_label}. "
                f"Mark #{ids[0]} in_progress before starting it."
            ),
            "tasks": [_serialize_task(r) for r in recs],
        }

    return execute
|
||||
|
||||
|
||||
def _make_update_executor(store: TaskStore):
    """Build the task_update executor.

    Handles four concerns in order: the synthetic 'deleted' status (routed
    to delete_task), owner auto-fill on in_progress, the veto-before-write
    task_completed hook, and finally the actual store write plus a
    next-step steering message in the result.
    """

    async def execute(inputs: dict) -> dict[str, Any]:
        list_id = _resolve_list_id()
        if not list_id:
            return {"success": False, "error": "No task_list_id resolved for this agent."}
        agent_id = current_agent_id() or ""
        task_id = int(inputs["id"])

        status_in = inputs.get("status")
        # 'deleted' is a synthetic status — handle it as a separate path.
        if status_in == "deleted":
            deleted, cascade = await store.delete_task(list_id, task_id)
            if not deleted:
                return {
                    "success": False,
                    "task_list_id": list_id,
                    "task_id": task_id,
                    "message": f"Task #{task_id} not found (already deleted?)",
                }
            await emit_task_deleted(task_list_id=list_id, task_id=task_id, cascade=cascade)
            return {
                "success": True,
                "task_list_id": list_id,
                "task_id": task_id,
                "deleted": True,
                "cascade": cascade,
                "message": f"Task #{task_id} deleted.",
            }

        # Auto-owner on in_progress.
        owner_in = inputs.get("owner", _OwnerSentinel)
        try:
            status_enum = TaskStatus(status_in) if status_in else None
        except ValueError:
            # An out-of-enum status string surfaces as a soft tool_result
            # (same convention as the batch executor's ValueError handling)
            # instead of propagating an exception out of the tool.
            return {
                "success": False,
                "task_list_id": list_id,
                "task_id": task_id,
                "message": (
                    f"Invalid status {status_in!r}. "
                    f"Valid values: {', '.join(_TASK_STATUS_VALUES)}."
                ),
            }
        if status_enum == TaskStatus.IN_PROGRESS and owner_in is _OwnerSentinel and agent_id:
            owner_in = agent_id

        # task_completed hook — fires BEFORE the write (Claude Code's
        # veto-before-write semantics). If the hook blocks, nothing
        # touches disk and no SSE event fires. The hook receives a
        # preview record with the intended new status so it can inspect
        # what's about to land.
        if status_enum == TaskStatus.COMPLETED:
            current = await store.get_task(list_id, task_id)
            if current is None:
                return {
                    "success": False,
                    "task_list_id": list_id,
                    "task_id": task_id,
                    "message": f"Task #{task_id} not found.",
                }
            # Already-completed tasks skip the hook: re-completing is a no-op
            # transition and shouldn't re-trigger veto logic.
            if current.status != TaskStatus.COMPLETED:
                preview = current.model_copy(update={"status": TaskStatus.COMPLETED})
                try:
                    await run_task_hooks(
                        HOOK_TASK_COMPLETED,
                        task_list_id=list_id,
                        task=preview,
                        agent_id=agent_id,
                    )
                except BlockingHookError as exc:
                    logger.warning("task_completed hook blocked #%s: %s", task_id, exc)
                    return {
                        "success": False,
                        "task_list_id": list_id,
                        "task_id": task_id,
                        "message": f"Hook blocked completion of #{task_id}: {exc}",
                        "task": _serialize_task(current),
                    }

        # Hook passed (or wasn't applicable) — proceed with the write.
        new, fields = await store.update_task(
            list_id,
            task_id,
            subject=inputs.get("subject"),
            description=inputs.get("description"),
            active_form=inputs.get("active_form"),
            owner=owner_in if owner_in is not _OwnerSentinel else _UNSET,
            status=status_enum,
            add_blocks=inputs.get("add_blocks"),
            add_blocked_by=inputs.get("add_blocked_by"),
            metadata_patch=inputs.get("metadata_patch"),
        )
        if new is None:
            # "Task not found" is not an error — keep is_error=False semantics.
            return {
                "success": False,
                "task_list_id": list_id,
                "task_id": task_id,
                "message": f"Task #{task_id} not found.",
            }

        if fields:
            await emit_task_updated(task_list_id=list_id, record=new, fields=fields)

        # Layer 4: tool-result steering. When a task just completed,
        # peek at remaining work and append a focused next-step nudge.
        # For hive's solo (non-claim) model, point at the lowest-id
        # pending task or signal "all done".
        message = f"Task #{task_id} updated. Fields changed: {', '.join(fields) or '(none)'}."
        if status_enum == TaskStatus.COMPLETED and "status" in fields:
            others = await store.list_tasks(list_id)
            completed_ids = {r.id for r in others if r.status == TaskStatus.COMPLETED}
            # First pending task with no unresolved blockers (ids ascend).
            next_pending = next(
                (
                    r
                    for r in others
                    if r.status == TaskStatus.PENDING and not [b for b in r.blocked_by if b not in completed_ids]
                ),
                None,
            )
            in_progress = [r for r in others if r.status == TaskStatus.IN_PROGRESS]
            if in_progress:
                names = ", ".join(f"#{r.id}" for r in in_progress[:3])
                message += f" Still in progress: {names}."
            elif next_pending is not None:
                message += (
                    f' Next pending: #{next_pending.id} — "{next_pending.subject}". '
                    f"Mark it in_progress before starting."
                )
            else:
                message += " All tasks complete. Wrap up: report results to the user and stop."

        return {
            "success": True,
            "task_list_id": list_id,
            "task_id": task_id,
            "fields": fields,
            "message": message,
            "task": _serialize_task(new),
        }

    return execute
|
||||
|
||||
|
||||
def _make_list_executor(store: TaskStore):
    """Build the task_list executor (pure read; registered concurrency-safe)."""

    async def execute(inputs: dict) -> dict[str, Any]:
        list_id = _resolve_list_id()
        if not list_id:
            return {"success": False, "error": "No task_list_id resolved for this agent."}
        records = await store.list_tasks(list_id)
        # _LIST_DESC (and _CREATE_DESC's metadata docs) promise that internal
        # tasks (metadata._internal=true) are hidden from task_list — enforce
        # that here at the tool layer so the contract holds regardless of
        # what the store returns.
        visible = [r for r in records if not (r.metadata or {}).get("_internal")]
        # Blocker resolution considers ALL records (including hidden ones)
        # so a completed internal blocker still unblocks its dependents.
        completed_ids = {r.id for r in records if r.status == TaskStatus.COMPLETED}
        rendered: list[str] = []
        for r in visible:
            unresolved_blockers = [b for b in r.blocked_by if b not in completed_ids]
            line_parts = [f"#{r.id}", f"[{r.status.value}]", r.subject]
            if r.owner:
                line_parts.append(f"({r.owner})")
            if unresolved_blockers:
                line_parts.append(f"[blocked by {', '.join(f'#{b}' for b in unresolved_blockers)}]")
            rendered.append(" ".join(line_parts))
        return {
            "success": True,
            "task_list_id": list_id,
            "count": len(visible),
            "lines": rendered,
            "tasks": [_serialize_task(r) for r in visible],
        }

    return execute
|
||||
|
||||
|
||||
def _make_get_executor(store: TaskStore):
|
||||
async def execute(inputs: dict) -> dict[str, Any]:
|
||||
list_id = _resolve_list_id()
|
||||
if not list_id:
|
||||
return {"success": False, "error": "No task_list_id resolved for this agent."}
|
||||
task_id = int(inputs["id"])
|
||||
rec = await store.get_task(list_id, task_id)
|
||||
if rec is None:
|
||||
return {
|
||||
"success": False,
|
||||
"task_list_id": list_id,
|
||||
"task_id": task_id,
|
||||
"message": f"Task #{task_id} not found.",
|
||||
}
|
||||
return {
|
||||
"success": True,
|
||||
"task_list_id": list_id,
|
||||
"task_id": task_id,
|
||||
"task": _serialize_task(rec),
|
||||
}
|
||||
|
||||
return execute
|
||||
|
||||
|
||||
# Sentinels so we can distinguish "owner not provided" from "owner=null".
|
||||
class _OwnerSentinel: # noqa: N801 — internal sentinel class
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_session_tools(
|
||||
store: TaskStore | None = None,
|
||||
) -> list[tuple[Tool, Any]]:
|
||||
"""Build (Tool, executor) pairs for the session task tools."""
|
||||
s = store or get_task_store()
|
||||
return [
|
||||
(
|
||||
Tool(
|
||||
name="task_create_batch",
|
||||
description=_CREATE_BATCH_DESC,
|
||||
parameters=_create_batch_schema(),
|
||||
concurrency_safe=False,
|
||||
),
|
||||
_make_create_batch_executor(s),
|
||||
),
|
||||
(
|
||||
Tool(
|
||||
name="task_create",
|
||||
description=_CREATE_DESC,
|
||||
parameters=_create_schema(),
|
||||
concurrency_safe=False,
|
||||
),
|
||||
_make_create_executor(s),
|
||||
),
|
||||
(
|
||||
Tool(
|
||||
name="task_update",
|
||||
description=_UPDATE_DESC,
|
||||
parameters=_update_schema(),
|
||||
concurrency_safe=False,
|
||||
),
|
||||
_make_update_executor(s),
|
||||
),
|
||||
(
|
||||
Tool(
|
||||
name="task_list",
|
||||
description=_LIST_DESC,
|
||||
parameters=_list_schema(),
|
||||
concurrency_safe=True,
|
||||
),
|
||||
_make_list_executor(s),
|
||||
),
|
||||
(
|
||||
Tool(
|
||||
name="task_get",
|
||||
description=_GET_DESC,
|
||||
parameters=_get_schema(),
|
||||
concurrency_safe=True,
|
||||
),
|
||||
_make_get_executor(s),
|
||||
),
|
||||
]
|
||||
@@ -116,6 +116,9 @@ class WorkerSessionAdapter:
|
||||
worker_path: Path | None = None
|
||||
|
||||
|
||||
QUEEN_PHASES: frozenset[str] = frozenset({"independent", "incubating", "working", "reviewing"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueenPhaseState:
|
||||
"""Mutable state container for queen operating phase.
|
||||
@@ -131,7 +134,7 @@ class QueenPhaseState:
|
||||
that trigger phase transitions.
|
||||
"""
|
||||
|
||||
phase: str = "independent" # "independent", "incubating", "working", or "reviewing"
|
||||
phase: str = "independent" # one of QUEEN_PHASES
|
||||
independent_tools: list = field(default_factory=list) # list[Tool]
|
||||
incubating_tools: list = field(default_factory=list) # list[Tool]
|
||||
working_tools: list = field(default_factory=list) # list[Tool]
|
||||
@@ -182,10 +185,31 @@ class QueenPhaseState:
|
||||
# Cached recall blocks — populated async by recall_selector after each turn.
|
||||
_cached_global_recall_block: str = ""
|
||||
_cached_queen_recall_block: str = ""
|
||||
# Cached dynamic system-prompt suffix — frozen at user-turn boundaries so
|
||||
# AgentLoop iterations within a single turn send a byte-stable prompt and
|
||||
# Anthropic's prompt cache keeps the static block warm. Rebuilt by
|
||||
# refresh_dynamic_suffix() on CLIENT_INPUT_RECEIVED and on phase change.
|
||||
_cached_dynamic_suffix: str = ""
|
||||
# Memory directories.
|
||||
global_memory_dir: Path | None = None
|
||||
queen_memory_dir: Path | None = None
|
||||
|
||||
# Per-queen MCP tool allowlist for the INDEPENDENT phase. ``None`` means
|
||||
# "allow every MCP tool" (default, backward-compatible). An explicit list
|
||||
# is authoritative: only tools whose name appears here pass through.
|
||||
# Lifecycle / synthetic tools bypass this gate regardless.
|
||||
enabled_mcp_tools: list[str] | None = None
|
||||
# Union of every MCP-origin tool name currently registered — the set the
|
||||
# allowlist can gate. Populated once at queen boot from
|
||||
# ``ToolRegistry._mcp_server_tools``. Names outside this set (lifecycle,
|
||||
# ``ask_user``) always pass through the filter.
|
||||
mcp_tool_names_all: set = field(default_factory=set)
|
||||
# Memoized output of the filter applied to ``independent_tools``.
|
||||
# Recomputed only when ``enabled_mcp_tools`` or ``independent_tools``
|
||||
# changes, so ``get_current_tools()`` in the independent phase returns
|
||||
# a byte-stable list between saves and the LLM prompt cache stays warm.
|
||||
_filtered_independent_tools: list = field(default_factory=list)
|
||||
|
||||
async def switch_to_working(self, source: str = "tool") -> None:
|
||||
"""Switch to working phase — colony workers are running.
|
||||
|
||||
@@ -204,6 +228,44 @@ class QueenPhaseState:
|
||||
"Colony workers are running. Available tools: " + ", ".join(tool_names) + "."
|
||||
)
|
||||
|
||||
def rebuild_independent_filter(self) -> None:
|
||||
"""Recompute the memoized independent-phase tool list.
|
||||
|
||||
Called once at queen boot (after ``independent_tools``,
|
||||
``mcp_tool_names_all`` and ``enabled_mcp_tools`` are all populated)
|
||||
and again from the tools-PATCH handler whenever the allowlist
|
||||
changes. Keeping the result memoized means the independent-phase
|
||||
branch of ``get_current_tools()`` returns the same Python list
|
||||
object across turns, so the LLM prompt cache stays warm until
|
||||
the user explicitly edits their allowlist.
|
||||
"""
|
||||
if self.enabled_mcp_tools is None:
|
||||
self._filtered_independent_tools = list(self.independent_tools)
|
||||
return
|
||||
allowed = set(self.enabled_mcp_tools)
|
||||
# If ``mcp_tool_names_all`` is empty, every tool falls through the
|
||||
# "not in mcp_tool_names_all" branch below and the allowlist is
|
||||
# silently ignored. That's a fail-open bug (the symptom: a
|
||||
# role-restricted queen sees every MCP tool). Log a warning so the
|
||||
# upstream cause is visible next time it happens.
|
||||
if not self.mcp_tool_names_all:
|
||||
logger.warning(
|
||||
"rebuild_independent_filter: mcp_tool_names_all is empty but "
|
||||
"allowlist has %d entries — allowlist cannot be applied. "
|
||||
"Check that queen boot populated phase_state.mcp_tool_names_all.",
|
||||
len(allowed),
|
||||
)
|
||||
self._filtered_independent_tools = [
|
||||
t for t in self.independent_tools if t.name not in self.mcp_tool_names_all or t.name in allowed
|
||||
]
|
||||
logger.info(
|
||||
"rebuild_independent_filter: allowlist=%d, mcp_names=%d, independent=%d -> filtered=%d",
|
||||
len(allowed),
|
||||
len(self.mcp_tool_names_all),
|
||||
len(self.independent_tools),
|
||||
len(self._filtered_independent_tools),
|
||||
)
|
||||
|
||||
def get_current_tools(self) -> list:
|
||||
"""Return tools for the current phase."""
|
||||
if self.phase == "working":
|
||||
@@ -212,11 +274,29 @@ class QueenPhaseState:
|
||||
return list(self.reviewing_tools)
|
||||
if self.phase == "incubating":
|
||||
return list(self.incubating_tools)
|
||||
# Default / "independent" — DM mode with full MCP tools.
|
||||
return list(self.independent_tools)
|
||||
# Default / "independent" — DM mode with full MCP tools, gated by
|
||||
# the per-queen allowlist. Return the memoized list directly so the
|
||||
# JSON sent to the LLM is byte-identical turn-to-turn.
|
||||
if not self._filtered_independent_tools and self.independent_tools:
|
||||
# Safety net: first-call in tests or code paths that skipped
|
||||
# the explicit boot-time rebuild.
|
||||
self.rebuild_independent_filter()
|
||||
return self._filtered_independent_tools
|
||||
|
||||
def get_current_prompt(self) -> str:
|
||||
"""Return the system prompt for the current phase."""
|
||||
def get_static_prompt(self) -> str:
|
||||
"""Return the stable portion of the system prompt for the current phase.
|
||||
|
||||
Includes identity, phase-role prompt, connected-integrations block,
|
||||
skills catalog, and default skill protocols. These change only on
|
||||
phase transition, queen identity selection, or when the user adds/
|
||||
removes an integration — rare events. Designed to be byte-stable
|
||||
across AgentLoop iterations within a single user turn so that
|
||||
Anthropic's prompt cache keeps this block warm.
|
||||
|
||||
The dynamic tail (recall + timestamp) is returned separately by
|
||||
``get_dynamic_suffix()``; the LLM wrapper emits them as two system
|
||||
content blocks with a cache breakpoint between them.
|
||||
"""
|
||||
if self.phase == "working":
|
||||
base = self.prompt_working
|
||||
elif self.phase == "reviewing":
|
||||
@@ -250,11 +330,51 @@ class QueenPhaseState:
|
||||
parts.append(catalog_prompt)
|
||||
if self.protocols_prompt:
|
||||
parts.append(self.protocols_prompt)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def refresh_dynamic_suffix(self) -> str:
|
||||
"""Rebuild and cache the dynamic system-prompt suffix.
|
||||
|
||||
The suffix contains recall blocks only. Called from the
|
||||
CLIENT_INPUT_RECEIVED subscriber so the suffix is byte-stable across
|
||||
every AgentLoop iteration within a single user turn.
|
||||
|
||||
Timestamps used to live here too; they were moved into the
|
||||
conversation itself as a ``[YYYY-MM-DD HH:MM TZ]`` prefix on each
|
||||
injected event (see ``drain_injection_queue``) so they ride on
|
||||
byte-stable conversation history instead of busting the
|
||||
per-turn system-prompt cache tail.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
if self._cached_global_recall_block:
|
||||
parts.append(self._cached_global_recall_block)
|
||||
if self._cached_queen_recall_block:
|
||||
parts.append(self._cached_queen_recall_block)
|
||||
return "\n\n".join(parts)
|
||||
self._cached_dynamic_suffix = "\n\n".join(parts)
|
||||
return self._cached_dynamic_suffix
|
||||
|
||||
def get_dynamic_suffix(self) -> str:
|
||||
"""Return the cached dynamic system-prompt suffix.
|
||||
|
||||
Lazily populates on first call so callers don't have to know about
|
||||
the refresh lifecycle. Subsequent calls return the cached string
|
||||
until ``refresh_dynamic_suffix()`` is invoked again.
|
||||
"""
|
||||
if not self._cached_dynamic_suffix:
|
||||
self.refresh_dynamic_suffix()
|
||||
return self._cached_dynamic_suffix
|
||||
|
||||
def get_current_prompt(self) -> str:
|
||||
"""Return the concatenated system prompt (static + dynamic).
|
||||
|
||||
Retained for backward compatibility and for callers that want one
|
||||
string (conversation persistence, debug dumps). The AgentLoop sends
|
||||
the two pieces separately to the LLM so the cache can break between
|
||||
them — see ``get_static_prompt()`` / ``get_dynamic_suffix()``.
|
||||
"""
|
||||
static = self.get_static_prompt()
|
||||
dynamic = self.get_dynamic_suffix()
|
||||
return f"{static}\n\n{dynamic}" if dynamic else static
|
||||
|
||||
async def _emit_phase_event(self) -> None:
|
||||
"""Publish a QUEEN_PHASE_CHANGED event so the frontend updates the tag."""
|
||||
@@ -1208,6 +1328,35 @@ def register_queen_lifecycle_tools(
|
||||
_pinned,
|
||||
)
|
||||
|
||||
# Publish a colony template entry per task BEFORE spawning so
|
||||
# the entries' template ids can be threaded into the spawn data
|
||||
# (workers' ctx.picked_up_from references them). This mirrors the
|
||||
# plan §5d "auto-populated by run_parallel_workers" behavior.
|
||||
_template_ids: list[int | None] = [None] * len(normalised)
|
||||
try:
|
||||
from framework.tasks import TaskListRole, get_task_store
|
||||
from framework.tasks.scoping import colony_task_list_id
|
||||
|
||||
_task_store = get_task_store()
|
||||
_template_list_id = colony_task_list_id(_colony_id or "primary")
|
||||
await _task_store.ensure_task_list(_template_list_id, role=TaskListRole.TEMPLATE)
|
||||
for i, spec in enumerate(normalised):
|
||||
rec = await _task_store.create_task(
|
||||
_template_list_id,
|
||||
subject=spec["task"][:200],
|
||||
description=spec["task"],
|
||||
)
|
||||
_template_ids[i] = rec.id
|
||||
# Thread the template id into the worker's spawn data so
|
||||
# ColonyRuntime.spawn populates ctx.picked_up_from correctly.
|
||||
spec["data"] = dict(spec.get("data") or {})
|
||||
spec["data"]["__template_task_id"] = rec.id
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"run_parallel_workers: colony template publish failed (non-fatal)",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
worker_ids = await colony.spawn_batch(
|
||||
normalised,
|
||||
@@ -1216,6 +1365,33 @@ def register_queen_lifecycle_tools(
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"spawn_batch failed: {e}"})
|
||||
|
||||
# Stamp `assigned_session` on each template entry post-spawn so the
|
||||
# UI's colony-overview panel can render the assigned-session chip.
|
||||
try:
|
||||
from framework.tasks.events import emit_colony_template_assignment
|
||||
from framework.tasks.scoping import session_task_list_id
|
||||
|
||||
for tid, wid in zip(_template_ids, worker_ids, strict=False):
|
||||
if tid is None:
|
||||
continue
|
||||
_assigned = session_task_list_id(wid, wid)
|
||||
await _task_store.update_task(
|
||||
_template_list_id,
|
||||
tid,
|
||||
metadata_patch={
|
||||
"assigned_session": _assigned,
|
||||
"assigned_worker_id": wid,
|
||||
},
|
||||
)
|
||||
await emit_colony_template_assignment(
|
||||
colony_id=_colony_id or "primary",
|
||||
task_id=tid,
|
||||
assigned_session=_assigned,
|
||||
assigned_worker_id=wid,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("run_parallel_workers: failed to stamp template assignments", exc_info=True)
|
||||
|
||||
# Phase transition — workers are now live, queen is in "working"
|
||||
# phase. Worker-finish auto-transitions back to "reviewing" once
|
||||
# every worker has reported (see queen_orchestrator._on_worker_report).
|
||||
@@ -1382,124 +1558,8 @@ def register_queen_lifecycle_tools(
|
||||
# re-runs idempotent.
|
||||
|
||||
import re as _re
|
||||
import shutil as _shutil
|
||||
|
||||
_COLONY_NAME_RE = _re.compile(r"^[a-z0-9_]+$")
|
||||
_SKILL_NAME_RE = _re.compile(r"^[a-z0-9-]+$")
|
||||
|
||||
def _materialize_skill_folder(
|
||||
*,
|
||||
skill_name: str,
|
||||
skill_description: str,
|
||||
skill_body: str,
|
||||
skill_files: list[dict] | None,
|
||||
colony_dir: Path,
|
||||
) -> tuple[Path | None, str | None, bool]:
|
||||
"""Write a skill folder under ``{colony_dir}/.hive/skills/{name}/`` from inline content.
|
||||
|
||||
The skill is scoped to a single colony: ``SkillDiscovery`` scans
|
||||
``{project_root}/.hive/skills/`` as project-scope, and the
|
||||
colony's worker uses ``project_root = colony_dir`` — so only
|
||||
that colony's workers see it, not every colony on the machine.
|
||||
We deliberately avoid ``~/.hive/skills/`` here because that
|
||||
directory is scanned as user scope and leaks into every agent.
|
||||
|
||||
Returns ``(installed_path, error, replaced)``. On success
|
||||
``error`` is ``None`` and ``installed_path`` is the final
|
||||
location; ``replaced`` is ``True`` when an existing skill with
|
||||
the same name was overwritten. On failure ``installed_path`` is
|
||||
``None``, ``error`` is a human-readable reason, and
|
||||
``replaced`` is ``False``.
|
||||
"""
|
||||
name = (skill_name or "").strip() if isinstance(skill_name, str) else ""
|
||||
if not name:
|
||||
return None, "skill_name is required", False
|
||||
if not _SKILL_NAME_RE.match(name):
|
||||
return None, (f"skill_name '{name}' must match [a-z0-9-] pattern"), False
|
||||
if name.startswith("-") or name.endswith("-") or "--" in name:
|
||||
return None, (f"skill_name '{name}' has leading/trailing/consecutive hyphens"), False
|
||||
if len(name) > 64:
|
||||
return None, f"skill_name '{name}' exceeds 64 chars", False
|
||||
|
||||
desc = (skill_description or "").strip() if isinstance(skill_description, str) else ""
|
||||
if not desc:
|
||||
return None, "skill_description is required", False
|
||||
if len(desc) > 1024:
|
||||
return None, "skill_description must be 1–1024 chars", False
|
||||
# Frontmatter descriptions must stay on a single line because
|
||||
# our frontmatter parser is line-oriented and the downstream
|
||||
# skill loader expects ``description:`` to resolve to one value.
|
||||
if "\n" in desc or "\r" in desc:
|
||||
return None, "skill_description must be a single line (no newlines)", False
|
||||
|
||||
body = skill_body if isinstance(skill_body, str) else ""
|
||||
if not body.strip():
|
||||
return (
|
||||
None,
|
||||
(
|
||||
"skill_body is required — the operational procedure the "
|
||||
"colony worker needs to run this job unattended"
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
# Optional supporting files (scripts/, references/, assets/…).
|
||||
# Each entry: {"path": "<relative>", "content": "<text>"}.
|
||||
normalized_files: list[tuple[Path, str]] = []
|
||||
if skill_files:
|
||||
if not isinstance(skill_files, list):
|
||||
return None, "skill_files must be an array", False
|
||||
for entry in skill_files:
|
||||
if not isinstance(entry, dict):
|
||||
return None, "each skill_files entry must be an object with 'path' and 'content'", False
|
||||
rel_raw = entry.get("path")
|
||||
content = entry.get("content")
|
||||
if not isinstance(rel_raw, str) or not rel_raw.strip():
|
||||
return None, "skill_files entry missing non-empty 'path'", False
|
||||
if not isinstance(content, str):
|
||||
return None, f"skill_files entry '{rel_raw}' missing string 'content'", False
|
||||
rel_stripped = rel_raw.strip()
|
||||
# Normalize a leading ``./`` but do NOT strip bare ``/`` —
|
||||
# an absolute path should be rejected, not silently relativized.
|
||||
if rel_stripped.startswith("./"):
|
||||
rel_stripped = rel_stripped[2:]
|
||||
rel_path = Path(rel_stripped)
|
||||
if rel_stripped.startswith("/") or rel_path.is_absolute() or ".." in rel_path.parts:
|
||||
return None, (f"skill_files path '{rel_raw}' must be relative and inside the skill folder"), False
|
||||
if rel_path.as_posix() == "SKILL.md":
|
||||
return None, ("skill_files must not contain SKILL.md — pass skill_body instead"), False
|
||||
normalized_files.append((rel_path, content))
|
||||
|
||||
target_root = colony_dir / ".hive" / "skills"
|
||||
target = target_root / name
|
||||
try:
|
||||
target_root.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as e:
|
||||
return None, f"failed to create skills root: {e}", False
|
||||
|
||||
replaced = False
|
||||
try:
|
||||
if target.exists():
|
||||
# Queen is re-creating a skill under the same name —
|
||||
# her latest content wins. rmtree first so stale files
|
||||
# from a prior version don't linger alongside the new
|
||||
# ones (copytree with dirs_exist_ok would merge them).
|
||||
replaced = True
|
||||
_shutil.rmtree(target)
|
||||
target.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
body_norm = body.rstrip() + "\n"
|
||||
skill_md_text = f"---\nname: {name}\ndescription: {desc}\n---\n\n{body_norm}"
|
||||
(target / "SKILL.md").write_text(skill_md_text, encoding="utf-8")
|
||||
|
||||
for rel_path, file_content in normalized_files:
|
||||
full_path = target / rel_path
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
full_path.write_text(file_content, encoding="utf-8")
|
||||
except OSError as e:
|
||||
return None, f"failed to write skill folder {target}: {e}", False
|
||||
|
||||
return target, None, replaced
|
||||
|
||||
def _validate_triggers(raw: Any) -> tuple[list[dict] | None, str | None]:
|
||||
"""Validate and normalize the ``triggers`` argument for create_colony.
|
||||
@@ -1645,17 +1705,26 @@ def register_queen_lifecycle_tools(
|
||||
except OSError as e:
|
||||
return json.dumps({"error": f"failed to create colony dir {colony_dir}: {e}"})
|
||||
|
||||
installed_skill, skill_err, skill_replaced = _materialize_skill_folder(
|
||||
# Validate + write via the shared authoring module so the HTTP
|
||||
# routes and this tool stay in lockstep.
|
||||
from framework.skills.authoring import build_draft, write_skill
|
||||
from framework.skills.overrides import (
|
||||
OverrideEntry,
|
||||
Provenance,
|
||||
SkillOverrideStore,
|
||||
utc_now,
|
||||
)
|
||||
|
||||
draft, draft_err = build_draft(
|
||||
skill_name=skill_name,
|
||||
skill_description=skill_description,
|
||||
skill_body=skill_body,
|
||||
skill_files=skill_files,
|
||||
colony_dir=colony_dir,
|
||||
)
|
||||
if skill_err is not None:
|
||||
if draft_err is not None or draft is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": skill_err,
|
||||
"error": draft_err or "invalid skill draft",
|
||||
"hint": (
|
||||
"Provide skill_name (lowercase [a-z0-9-], ≤64 chars), "
|
||||
"skill_description (single line, 1–1024 chars), and "
|
||||
@@ -1668,6 +1737,63 @@ def register_queen_lifecycle_tools(
|
||||
}
|
||||
)
|
||||
|
||||
installed_skill, write_err, skill_replaced = write_skill(
|
||||
draft,
|
||||
target_root=colony_dir / ".hive" / "skills",
|
||||
replace_existing=True,
|
||||
)
|
||||
if write_err is not None or installed_skill is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": write_err or "failed to write skill folder",
|
||||
}
|
||||
)
|
||||
|
||||
# Seed the colony's override ledger from the queen's current
|
||||
# state so the colony inherits everything she had enabled (preset
|
||||
# capability packs, toggled-off framework defaults, etc.) at fork
|
||||
# time. The colony then owns its own copy — later queen edits
|
||||
# don't retroactively alter this colony's skill surface.
|
||||
# On top of the seed we upsert the newly-written skill with
|
||||
# QUEEN_CREATED provenance so the UI renders + edits it properly.
|
||||
try:
|
||||
from framework.config import QUEENS_DIR
|
||||
|
||||
overrides_path = colony_dir / "skills_overrides.json"
|
||||
queen_id = getattr(session, "queen_name", None) or "unknown"
|
||||
colony_store = SkillOverrideStore.load(overrides_path, scope_label=f"colony:{cn}")
|
||||
|
||||
queen_overrides_path = QUEENS_DIR / queen_id / "skills_overrides.json"
|
||||
if queen_overrides_path.exists():
|
||||
queen_store = SkillOverrideStore.load(queen_overrides_path, scope_label=f"queen:{queen_id}")
|
||||
# Shallow clone: queen's explicit toggles + master switch
|
||||
# become the colony's starting state. Tombstones propagate
|
||||
# so a queen-deleted UI skill doesn't resurrect here.
|
||||
colony_store.all_defaults_disabled = queen_store.all_defaults_disabled
|
||||
for sname, entry in queen_store.overrides.items():
|
||||
# Don't overwrite an entry the colony already set
|
||||
# (rare on fresh fork; matters if this is a re-fork).
|
||||
if sname in colony_store.overrides:
|
||||
continue
|
||||
colony_store.upsert(sname, entry.clone())
|
||||
for sname in queen_store.deleted_ui_skills:
|
||||
colony_store.deleted_ui_skills.add(sname)
|
||||
|
||||
colony_store.upsert(
|
||||
draft.name,
|
||||
OverrideEntry(
|
||||
enabled=True,
|
||||
provenance=Provenance.QUEEN_CREATED,
|
||||
created_at=utc_now(),
|
||||
created_by=f"queen:{queen_id}",
|
||||
),
|
||||
)
|
||||
colony_store.save()
|
||||
except Exception:
|
||||
# Registration is best-effort; discovery still surfaces the
|
||||
# skill as project-scope even if the ledger fails to update.
|
||||
logger.warning("create_colony: override registration failed", exc_info=True)
|
||||
|
||||
logger.info(
|
||||
"create_colony: materialized skill at %s (replaced=%s)",
|
||||
installed_skill,
|
||||
|
||||
@@ -5,6 +5,8 @@ import ColonyChat from "./pages/colony-chat";
|
||||
import QueenDM from "./pages/queen-dm";
|
||||
import OrgChart from "./pages/org-chart";
|
||||
import PromptLibrary from "./pages/prompt-library";
|
||||
import SkillsLibrary from "./pages/skills-library";
|
||||
import ToolLibrary from "./pages/tool-library";
|
||||
import CredentialsPage from "./pages/credentials";
|
||||
import NotFound from "./pages/not-found";
|
||||
|
||||
@@ -16,7 +18,9 @@ function App() {
|
||||
<Route path="/colony/:colonyId" element={<ColonyChat />} />
|
||||
<Route path="/queen/:queenId" element={<QueenDM />} />
|
||||
<Route path="/org-chart" element={<OrgChart />} />
|
||||
<Route path="/skills-library" element={<SkillsLibrary />} />
|
||||
<Route path="/prompt-library" element={<PromptLibrary />} />
|
||||
<Route path="/tool-library" element={<ToolLibrary />} />
|
||||
<Route path="/credentials" element={<CredentialsPage />} />
|
||||
<Route path="*" element={<NotFound />} />
|
||||
</Route>
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import { api } from "./client";
|
||||
import type { ToolMeta, McpServerTools } from "./queens";
|
||||
|
||||
export interface ColonySummary {
|
||||
name: string;
|
||||
queen_name: string | null;
|
||||
created_at: string | null;
|
||||
has_allowlist: boolean;
|
||||
enabled_count: number | null;
|
||||
}
|
||||
|
||||
export interface ColonyToolsResponse {
|
||||
colony_name: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
}
|
||||
|
||||
export interface ColonyToolsUpdateResult {
|
||||
colony_name: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
refreshed_runtimes: number;
|
||||
note?: string;
|
||||
}
|
||||
|
||||
export const coloniesApi = {
|
||||
/** List every colony on disk with a summary of its tool allowlist. */
|
||||
list: () =>
|
||||
api.get<{ colonies: ColonySummary[] }>(`/colonies/tools-index`),
|
||||
|
||||
/** Enumerate a colony's tool surface (lifecycle + synthetic + MCP). */
|
||||
getTools: (colonyName: string) =>
|
||||
api.get<ColonyToolsResponse>(
|
||||
`/colony/${encodeURIComponent(colonyName)}/tools`,
|
||||
),
|
||||
|
||||
/** Persist a colony's MCP tool allowlist.
|
||||
*
|
||||
* ``null`` resets to "allow every MCP tool". A list of names enables
|
||||
* only those MCP tools. Changes take effect on the next worker spawn;
|
||||
* in-flight workers keep their booted tool list.
|
||||
*/
|
||||
updateTools: (colonyName: string, enabled: string[] | null) =>
|
||||
api.patch<ColonyToolsUpdateResult>(
|
||||
`/colony/${encodeURIComponent(colonyName)}/tools`,
|
||||
{ enabled_mcp_tools: enabled },
|
||||
),
|
||||
};
|
||||
@@ -0,0 +1,66 @@
|
||||
import { api } from "./client";
|
||||
|
||||
export type McpTransport = "stdio" | "http" | "sse" | "unix";
|
||||
|
||||
export interface McpServer {
|
||||
name: string;
|
||||
/** "local": added via UI/CLI (user-editable). "registry": installed from
|
||||
* the remote MCP registry. "built-in": baked into the queen package —
|
||||
* visible but not removable from the UI. */
|
||||
source: "local" | "registry" | "built-in";
|
||||
transport: McpTransport | string;
|
||||
description: string;
|
||||
enabled: boolean;
|
||||
last_health_status: "healthy" | "unhealthy" | null;
|
||||
last_error: string | null;
|
||||
last_health_check_at: string | null;
|
||||
tool_count: number | null;
|
||||
/** Servers flagged removable:false cannot be deleted from the UI. */
|
||||
removable?: boolean;
|
||||
}
|
||||
|
||||
export interface AddMcpServerBody {
|
||||
name: string;
|
||||
transport: McpTransport;
|
||||
/** stdio */
|
||||
command?: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
cwd?: string;
|
||||
/** http / sse */
|
||||
url?: string;
|
||||
headers?: Record<string, string>;
|
||||
/** unix */
|
||||
socket_path?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface McpHealthResult {
|
||||
name: string;
|
||||
status: "healthy" | "unhealthy" | "unknown";
|
||||
tools: number;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
/** Backend MCPError shape when an operation fails. */
|
||||
export interface McpErrorBody {
|
||||
error: string;
|
||||
code?: string;
|
||||
what?: string;
|
||||
why?: string;
|
||||
fix?: string;
|
||||
}
|
||||
|
||||
export const mcpApi = {
|
||||
listServers: () => api.get<{ servers: McpServer[] }>("/mcp/servers"),
|
||||
addServer: (body: AddMcpServerBody) =>
|
||||
api.post<{ server: McpServer; hint: string }>("/mcp/servers", body),
|
||||
removeServer: (name: string) =>
|
||||
api.delete<{ removed: string }>(`/mcp/servers/${encodeURIComponent(name)}`),
|
||||
setEnabled: (name: string, enabled: boolean) =>
|
||||
api.post<{ name: string; enabled: boolean }>(
|
||||
`/mcp/servers/${encodeURIComponent(name)}/${enabled ? "enable" : "disable"}`,
|
||||
),
|
||||
checkHealth: (name: string) =>
|
||||
api.post<McpHealthResult>(`/mcp/servers/${encodeURIComponent(name)}/health`),
|
||||
};
|
||||
@@ -16,6 +16,45 @@ export interface QueenSessionResult {
|
||||
status: "live" | "resumed" | "created";
|
||||
}
|
||||
|
||||
export interface ToolMeta {
|
||||
name: string;
|
||||
description: string;
|
||||
input_schema?: Record<string, unknown>;
|
||||
editable?: boolean;
|
||||
}
|
||||
|
||||
export interface McpServerTools {
|
||||
name: string;
|
||||
tools: Array<ToolMeta & { enabled: boolean }>;
|
||||
}
|
||||
|
||||
export interface QueenToolsResponse {
|
||||
queen_id: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
/** True when the effective allowlist comes from the role-based default
|
||||
* (no tools.json sidecar saved for this queen). False means the user
|
||||
* has explicitly saved an allowlist. */
|
||||
is_role_default: boolean;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
}
|
||||
|
||||
export interface QueenToolsUpdateResult {
|
||||
queen_id: string;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
refreshed_sessions: number;
|
||||
}
|
||||
|
||||
export interface QueenToolsResetResult {
|
||||
queen_id: string;
|
||||
removed: boolean;
|
||||
enabled_mcp_tools: string[] | null;
|
||||
is_role_default: true;
|
||||
refreshed_sessions: number;
|
||||
}
|
||||
|
||||
export const queensApi = {
|
||||
/** List all queen profiles (id, name, title). */
|
||||
list: () =>
|
||||
@@ -57,4 +96,24 @@ export const queensApi = {
|
||||
initial_prompt: initialPrompt,
|
||||
initial_phase: initialPhase || undefined,
|
||||
}),
|
||||
|
||||
/** Enumerate the queen's tool surface (lifecycle + synthetic + MCP). */
|
||||
getTools: (queenId: string) =>
|
||||
api.get<QueenToolsResponse>(`/queen/${queenId}/tools`),
|
||||
|
||||
/** Persist the MCP tool allowlist for a queen.
|
||||
*
|
||||
* Pass ``null`` to explicitly allow every MCP tool, or a list to
|
||||
* restrict the queen's tool surface. Lifecycle and synthetic tools
|
||||
* are always enabled and cannot be listed here.
|
||||
*/
|
||||
updateTools: (queenId: string, enabled: string[] | null) =>
|
||||
api.patch<QueenToolsUpdateResult>(`/queen/${queenId}/tools`, {
|
||||
enabled_mcp_tools: enabled,
|
||||
}),
|
||||
|
||||
/** Drop the queen's tools.json sidecar so she falls back to the
|
||||
* role-based default (or allow-all for queens without a role entry). */
|
||||
resetTools: (queenId: string) =>
|
||||
api.delete<QueenToolsResetResult>(`/queen/${queenId}/tools`),
|
||||
};
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
import { api } from "./client";
|
||||
|
||||
export type SkillScopeKind = "queen" | "colony" | "user";
|
||||
|
||||
export type SkillProvenance =
|
||||
| "framework"
|
||||
| "preset"
|
||||
| "user_dropped"
|
||||
| "user_ui_created"
|
||||
| "queen_created"
|
||||
| "learned_runtime"
|
||||
| "project_dropped"
|
||||
| "other";
|
||||
|
||||
export interface SkillOwner {
|
||||
type: "queen" | "colony";
|
||||
id: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
export interface SkillRow {
|
||||
name: string;
|
||||
description: string;
|
||||
source_scope: string;
|
||||
provenance: SkillProvenance;
|
||||
enabled: boolean;
|
||||
editable: boolean;
|
||||
deletable: boolean;
|
||||
location: string;
|
||||
base_dir?: string;
|
||||
visibility: string[] | null;
|
||||
trust: string | null;
|
||||
created_at: string | null;
|
||||
created_by: string | null;
|
||||
notes: string | null;
|
||||
param_overrides?: Record<string, unknown>;
|
||||
owner?: SkillOwner | null;
|
||||
visible_to?: { queens: string[]; colonies: string[] };
|
||||
enabled_by_default?: boolean;
|
||||
}
|
||||
|
||||
export interface ScopeSkillsResponse {
|
||||
queen_id?: string;
|
||||
colony_name?: string;
|
||||
all_defaults_disabled: boolean;
|
||||
skills: SkillRow[];
|
||||
inherited_from_queen?: string[];
|
||||
}
|
||||
|
||||
export interface AggregatedSkillsResponse {
|
||||
skills: SkillRow[];
|
||||
queens: Array<{ id: string; name: string }>;
|
||||
colonies: Array<{ name: string; queen_id: string | null }>;
|
||||
}
|
||||
|
||||
export interface SkillScopesResponse {
|
||||
queens: Array<{ id: string; name: string }>;
|
||||
colonies: Array<{ name: string; queen_id: string | null }>;
|
||||
}
|
||||
|
||||
export interface SkillDetailResponse {
|
||||
name: string;
|
||||
description: string;
|
||||
source_scope: string;
|
||||
location: string;
|
||||
base_dir: string;
|
||||
body: string;
|
||||
visibility: string[] | null;
|
||||
}
|
||||
|
||||
export interface SkillCreatePayload {
|
||||
name: string;
|
||||
description: string;
|
||||
body: string;
|
||||
files?: Array<{ path: string; content: string }>;
|
||||
enabled?: boolean;
|
||||
notes?: string | null;
|
||||
replace_existing?: boolean;
|
||||
}
|
||||
|
||||
export interface SkillPatchPayload {
|
||||
enabled?: boolean;
|
||||
param_overrides?: Record<string, unknown>;
|
||||
notes?: string | null;
|
||||
all_defaults_disabled?: boolean;
|
||||
}
|
||||
|
||||
const scopePath = (scope: "queen" | "colony", targetId: string) =>
|
||||
scope === "queen"
|
||||
? `/queen/${encodeURIComponent(targetId)}/skills`
|
||||
: `/colonies/${encodeURIComponent(targetId)}/skills`;
|
||||
|
||||
export const skillsApi = {
|
||||
// Aggregated library
|
||||
listAll: () => api.get<AggregatedSkillsResponse>("/skills"),
|
||||
listScopes: () => api.get<SkillScopesResponse>("/skills/scopes"),
|
||||
getDetail: (name: string) =>
|
||||
api.get<SkillDetailResponse>(`/skills/${encodeURIComponent(name)}`),
|
||||
|
||||
// Per-scope
|
||||
listForQueen: (queenId: string) =>
|
||||
api.get<ScopeSkillsResponse>(`/queen/${encodeURIComponent(queenId)}/skills`),
|
||||
listForColony: (colonyName: string) =>
|
||||
api.get<ScopeSkillsResponse>(
|
||||
`/colonies/${encodeURIComponent(colonyName)}/skills`,
|
||||
),
|
||||
|
||||
create: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
payload: SkillCreatePayload,
|
||||
) => api.post<SkillRow>(scopePath(scope, targetId), payload),
|
||||
|
||||
patch: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
skillName: string,
|
||||
payload: SkillPatchPayload,
|
||||
) =>
|
||||
api.patch<{ name: string; enabled: boolean | null; ok: boolean }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}`,
|
||||
payload,
|
||||
),
|
||||
|
||||
putBody: (
|
||||
scope: "queen" | "colony",
|
||||
targetId: string,
|
||||
skillName: string,
|
||||
payload: { body: string; description?: string },
|
||||
) =>
|
||||
api.put<{ name: string; installed_path: string }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}/body`,
|
||||
payload,
|
||||
),
|
||||
|
||||
remove: (scope: "queen" | "colony", targetId: string, skillName: string) =>
|
||||
api.delete<{ name: string; removed: boolean }>(
|
||||
`${scopePath(scope, targetId)}/${encodeURIComponent(skillName)}`,
|
||||
),
|
||||
|
||||
reload: (scope: "queen" | "colony", targetId: string) =>
|
||||
api.post<{ ok: boolean }>(`${scopePath(scope, targetId)}/reload`),
|
||||
|
||||
// Multipart upload. File may be a SKILL.md or a .zip bundle.
|
||||
upload: (formData: FormData) =>
|
||||
api.upload<{
|
||||
name: string;
|
||||
installed_path: string;
|
||||
replaced: boolean;
|
||||
scope: SkillScopeKind;
|
||||
target_id: string | null;
|
||||
enabled: boolean;
|
||||
}>("/skills/upload", formData),
|
||||
};
|
||||
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* REST + types for the task system.
|
||||
*
|
||||
* Two list types:
|
||||
* colony:{colony_id} — colony template (queen's spawn plan)
|
||||
* session:{agent_id}:{sess_id} — per-session working list
|
||||
*
|
||||
* Each agent operates on its OWN session list via the four task tools;
|
||||
* the colony template is queen-owned and read by the UI.
|
||||
*/
|
||||
|
||||
import { api, ApiError } from "./client";
|
||||
|
||||
export type TaskStatus = "pending" | "in_progress" | "completed";
|
||||
export type TaskListRole = "template" | "session";
|
||||
|
||||
export interface TaskRecord {
|
||||
id: number;
|
||||
subject: string;
|
||||
description: string;
|
||||
active_form: string | null;
|
||||
owner: string | null;
|
||||
status: TaskStatus;
|
||||
blocks: number[];
|
||||
blocked_by: number[];
|
||||
metadata: Record<string, unknown>;
|
||||
created_at: number;
|
||||
updated_at: number;
|
||||
}
|
||||
|
||||
export interface TaskListSnapshot {
|
||||
task_list_id: string;
|
||||
role: TaskListRole;
|
||||
meta: {
|
||||
task_list_id: string;
|
||||
role: TaskListRole;
|
||||
creator_agent_id: string | null;
|
||||
created_at: number;
|
||||
last_seen_session_ids: string[];
|
||||
schema_version: number;
|
||||
} | null;
|
||||
tasks: TaskRecord[];
|
||||
}
|
||||
|
||||
export interface ColonyTaskLists {
|
||||
template_task_list_id: string;
|
||||
queen_session_task_list_id: string | null;
|
||||
}
|
||||
|
||||
export interface SessionTaskListInfo {
|
||||
task_list_id: string | null;
|
||||
picked_up_from: { colony_id: string; task_id: number } | null;
|
||||
}
|
||||
|
||||
export const tasksApi = {
|
||||
/**
|
||||
* Snapshot of one task list, identified by its full task_list_id.
|
||||
*
|
||||
* Returns ``null`` if the list does not exist on disk yet (404). That
|
||||
* happens when a session has just started and no agent has called
|
||||
* ``task_create`` — the panel should hide until the first task is
|
||||
* created instead of surfacing the 404 as an error.
|
||||
*/
|
||||
async getList(taskListId: string): Promise<TaskListSnapshot | null> {
|
||||
try {
|
||||
return await api.get<TaskListSnapshot>(`/tasks/${encodeURIComponent(taskListId)}`);
|
||||
} catch (err) {
|
||||
if (err instanceof ApiError && err.status === 404) return null;
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
/** Helper: resolve template + queen-session list ids for a colony. */
|
||||
async getColonyLists(
|
||||
colonyId: string,
|
||||
queenSessionId?: string,
|
||||
): Promise<ColonyTaskLists> {
|
||||
const qs = queenSessionId ? `?queen_session_id=${encodeURIComponent(queenSessionId)}` : "";
|
||||
return api.get<ColonyTaskLists>(`/colonies/${encodeURIComponent(colonyId)}/task_lists${qs}`);
|
||||
},
|
||||
/** Helper: resolve task_list_id + picked_up_from for a session. */
|
||||
async getSessionInfo(
|
||||
sessionId: string,
|
||||
agentId: string = "queen",
|
||||
): Promise<SessionTaskListInfo> {
|
||||
return api.get<SessionTaskListInfo>(
|
||||
`/sessions/${encodeURIComponent(sessionId)}/task_list_id?agent_id=${encodeURIComponent(agentId)}`,
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SSE event payload shapes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export interface TaskCreatedEvent {
|
||||
task_list_id: string;
|
||||
task: TaskRecord;
|
||||
}
|
||||
|
||||
export interface TaskUpdatedEvent {
|
||||
task_list_id: string;
|
||||
task_id: number;
|
||||
after: TaskRecord;
|
||||
fields: string[];
|
||||
}
|
||||
|
||||
export interface TaskDeletedEvent {
|
||||
task_list_id: string;
|
||||
task_id: number;
|
||||
cascade: number[];
|
||||
}
|
||||
|
||||
export interface ColonyTemplateAssignmentEvent {
|
||||
colony_id: string;
|
||||
task_id: number;
|
||||
assigned_session: string | null;
|
||||
assigned_worker_id: string | null;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Parse a task_list_id into structured parts (mirrors server-side scoping). */
|
||||
export function parseTaskListId(taskListId: string): {
|
||||
kind: "colony" | "session" | "raw";
|
||||
colony_id?: string;
|
||||
agent_id?: string;
|
||||
session_id?: string;
|
||||
raw?: string;
|
||||
} {
|
||||
if (taskListId.startsWith("colony:")) {
|
||||
return { kind: "colony", colony_id: taskListId.slice("colony:".length) };
|
||||
}
|
||||
if (taskListId.startsWith("session:")) {
|
||||
const rest = taskListId.slice("session:".length);
|
||||
const idx = rest.indexOf(":");
|
||||
return idx > 0
|
||||
? { kind: "session", agent_id: rest.slice(0, idx), session_id: rest.slice(idx + 1) }
|
||||
: { kind: "raw", raw: taskListId };
|
||||
}
|
||||
return { kind: "raw", raw: taskListId };
|
||||
}
|
||||
|
||||
export function colonyTaskListId(colonyId: string): string {
|
||||
return `colony:${colonyId}`;
|
||||
}
|
||||
|
||||
export function sessionTaskListId(agentId: string, sessionId: string): string {
|
||||
return `session:${agentId}:${sessionId}`;
|
||||
}
|
||||
@@ -287,7 +287,13 @@ export type EventTypeName =
|
||||
| "trigger_fired"
|
||||
| "trigger_removed"
|
||||
| "trigger_updated"
|
||||
| "queen_identity_selected";
|
||||
| "queen_identity_selected"
|
||||
| "task_created"
|
||||
| "task_updated"
|
||||
| "task_deleted"
|
||||
| "task_list_reset"
|
||||
| "task_list_reattach_mismatch"
|
||||
| "colony_template_assignment";
|
||||
|
||||
export interface AgentEvent {
|
||||
type: EventTypeName;
|
||||
|
||||
@@ -151,8 +151,11 @@ interface ChatPanelProps {
|
||||
onStartNewSession?: () => void;
|
||||
/** When true, disable the start-new-session button (request in flight). */
|
||||
startingNewSession?: boolean;
|
||||
/** Cumulative LLM token usage for this session */
|
||||
tokenUsage?: { input: number; output: number };
|
||||
/** Cumulative LLM token usage for this session.
|
||||
* `cached` (cache reads) and `cacheCreated` (cache writes) are subsets of
|
||||
* `input` — providers count both inside prompt_tokens. Display them
|
||||
* separately; do not add to a total. */
|
||||
tokenUsage?: { input: number; output: number; cached?: number; cacheCreated?: number; costUsd?: number };
|
||||
/** Optional action element rendered on the right side of the "Conversation" header */
|
||||
headerAction?: React.ReactNode;
|
||||
}
|
||||
@@ -1482,11 +1485,41 @@ export default function ChatPanel({
|
||||
Context: {fmt(queenUsage.estimatedTokens)}/{fmt(queenUsage.maxTokens)}
|
||||
</span>
|
||||
)}
|
||||
{hasTokens && (
|
||||
<span title="LLM tokens used this session (input + output)">
|
||||
Tokens: {fmt(tokenUsage!.input + tokenUsage!.output)}
|
||||
</span>
|
||||
)}
|
||||
{hasTokens && (() => {
|
||||
const cached = tokenUsage!.cached ?? 0;
|
||||
const created = tokenUsage!.cacheCreated ?? 0;
|
||||
const cost = tokenUsage!.costUsd ?? 0;
|
||||
// cached/created are subsets of input — never sum; surface separately.
|
||||
// Cost can be < $0.01; show 4 decimals so small-model sessions aren't "$0.00".
|
||||
const costStr = cost > 0 ? `$${cost.toFixed(4)}` : "—";
|
||||
return (
|
||||
<span className="group relative cursor-help transition-colors hover:text-muted-foreground">
|
||||
Tokens: {fmt(tokenUsage!.input + tokenUsage!.output)}
|
||||
<span
|
||||
role="tooltip"
|
||||
className="pointer-events-none invisible absolute bottom-full right-0 z-50 mb-2 whitespace-nowrap rounded-md border border-border bg-popover px-3 py-2 text-[11px] text-popover-foreground opacity-0 shadow-lg transition-[opacity,transform] duration-150 translate-y-1 group-hover:visible group-hover:opacity-100 group-hover:translate-y-0"
|
||||
>
|
||||
<span className="mb-1.5 block text-muted-foreground">
|
||||
LLM tokens used this session
|
||||
</span>
|
||||
<span className="grid grid-cols-[auto_1fr] gap-x-4 gap-y-0.5 tabular-nums">
|
||||
<span>Input</span>
|
||||
<span className="text-right">{fmt(tokenUsage!.input)}</span>
|
||||
<span className="pl-3 text-muted-foreground">cache read</span>
|
||||
<span className="text-right text-muted-foreground">{fmt(cached)}</span>
|
||||
<span className="pl-3 text-muted-foreground">cache write</span>
|
||||
<span className="text-right text-muted-foreground">{fmt(created)}</span>
|
||||
<span>Output</span>
|
||||
<span className="text-right">{fmt(tokenUsage!.output)}</span>
|
||||
<span className="mt-1 border-t border-border/50 pt-1">Cost</span>
|
||||
<span className="mt-1 border-t border-border/50 pt-1 text-right font-medium">
|
||||
{costStr}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
})()}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { useCallback } from "react";
|
||||
import { coloniesApi } from "@/api/colonies";
|
||||
import ToolsEditor from "./ToolsEditor";
|
||||
|
||||
export default function ColonyToolsSection({
|
||||
colonyName,
|
||||
}: {
|
||||
colonyName: string;
|
||||
}) {
|
||||
const fetchSnapshot = useCallback(
|
||||
() => coloniesApi.getTools(colonyName),
|
||||
[colonyName],
|
||||
);
|
||||
const saveAllowlist = useCallback(
|
||||
(enabled: string[] | null) => coloniesApi.updateTools(colonyName, enabled),
|
||||
[colonyName],
|
||||
);
|
||||
return (
|
||||
<ToolsEditor
|
||||
subjectKey={`colony:${colonyName}`}
|
||||
title="Tools"
|
||||
caveat="Changes apply to the next worker spawn. Running workers keep the tool list they booted with."
|
||||
fetchSnapshot={fetchSnapshot}
|
||||
saveAllowlist={saveAllowlist}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -35,6 +35,10 @@ import { sessionsApi } from "@/api/sessions";
|
||||
import { cronToLabel } from "@/lib/graphUtils";
|
||||
import type { GraphNode } from "@/components/graph-types";
|
||||
import { useColonyWorkers } from "@/context/ColonyWorkersContext";
|
||||
import TaskListPanel from "@/components/TaskListPanel";
|
||||
|
||||
// Re-export so the WorkerDetail block can use it without forward decl.
|
||||
const TaskListPanelLazy = TaskListPanel;
|
||||
import { DataGrid, type SortDir } from "@/components/data-grid";
|
||||
|
||||
interface ColonyWorkersPanelProps {
|
||||
@@ -1607,6 +1611,11 @@ function WorkerDetail({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Worker session task list — embedded panel, not a separate rail. */}
|
||||
<div className="mb-3">
|
||||
<WorkerTaskList workerId={workerId} colonyName={colonyName} />
|
||||
</div>
|
||||
|
||||
{isHistorical ? (
|
||||
<HistoricalWorkerPlaceholder workerId={workerId} />
|
||||
) : (
|
||||
@@ -1616,6 +1625,28 @@ function WorkerDetail({
|
||||
);
|
||||
}
|
||||
|
||||
function WorkerTaskList({
|
||||
workerId,
|
||||
colonyName: _colonyName,
|
||||
}: {
|
||||
workerId: string;
|
||||
colonyName: string | null;
|
||||
}) {
|
||||
// Workers' task_list_id is session:{worker_id}:{worker_id} (the worker's
|
||||
// session_id == its worker_id under ColonyRuntime.spawn). The SSE
|
||||
// events for it ride on the colony's bus, which we subscribe to via
|
||||
// the queen's session id (already streaming in this view).
|
||||
const { sessionId } = useColonyWorkers();
|
||||
return (
|
||||
<TaskListPanelLazy
|
||||
taskListId={`session:${workerId}:${workerId}`}
|
||||
sessionId={sessionId ?? ""}
|
||||
title="Worker session"
|
||||
variant="embedded"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function LiveWorkerProgress({
|
||||
colonyName,
|
||||
workerId,
|
||||
|
||||
@@ -0,0 +1,651 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
Plus,
|
||||
Trash2,
|
||||
RefreshCw,
|
||||
Loader2,
|
||||
AlertCircle,
|
||||
Check,
|
||||
X,
|
||||
Server,
|
||||
CircleCheck,
|
||||
CircleAlert,
|
||||
CircleDashed,
|
||||
} from "lucide-react";
|
||||
import {
|
||||
mcpApi,
|
||||
type McpServer,
|
||||
type McpTransport,
|
||||
type AddMcpServerBody,
|
||||
} from "@/api/mcp";
|
||||
|
||||
type TransportKey = McpTransport;
|
||||
|
||||
const TRANSPORT_OPTIONS: TransportKey[] = ["stdio", "http", "sse", "unix"];
|
||||
|
||||
function healthBadge(server: McpServer) {
|
||||
if (!server.enabled) {
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-muted-foreground">
|
||||
<CircleDashed className="w-3 h-3" /> Disabled
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (server.last_health_status === "healthy") {
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-green-500">
|
||||
<CircleCheck className="w-3 h-3" /> Healthy
|
||||
</span>
|
||||
);
|
||||
}
|
||||
if (server.last_health_status === "unhealthy") {
|
||||
return (
|
||||
<span
|
||||
className="flex items-center gap-1 text-[11px] text-red-400"
|
||||
title={server.last_error || "Unhealthy"}
|
||||
>
|
||||
<CircleAlert className="w-3 h-3" /> Unhealthy
|
||||
</span>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<span className="flex items-center gap-1 text-[11px] text-muted-foreground">
|
||||
<CircleDashed className="w-3 h-3" /> Unknown
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
interface AddFormState {
|
||||
name: string;
|
||||
transport: TransportKey;
|
||||
command: string;
|
||||
args: string;
|
||||
env: string;
|
||||
cwd: string;
|
||||
url: string;
|
||||
headers: string;
|
||||
socketPath: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
const EMPTY_FORM: AddFormState = {
|
||||
name: "",
|
||||
transport: "stdio",
|
||||
command: "",
|
||||
args: "",
|
||||
env: "",
|
||||
cwd: "",
|
||||
url: "",
|
||||
headers: "",
|
||||
socketPath: "",
|
||||
description: "",
|
||||
};
|
||||
|
||||
function parseKeyValueLines(text: string): Record<string, string> {
|
||||
const out: Record<string, string> = {};
|
||||
text
|
||||
.split("\n")
|
||||
.map((l) => l.trim())
|
||||
.filter(Boolean)
|
||||
.forEach((line) => {
|
||||
const eq = line.indexOf("=");
|
||||
if (eq < 0) return;
|
||||
const k = line.slice(0, eq).trim();
|
||||
const v = line.slice(eq + 1).trim();
|
||||
if (k) out[k] = v;
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
function buildAddBody(form: AddFormState): AddMcpServerBody {
|
||||
const body: AddMcpServerBody = {
|
||||
name: form.name.trim(),
|
||||
transport: form.transport,
|
||||
description: form.description.trim() || undefined,
|
||||
};
|
||||
if (form.transport === "stdio") {
|
||||
body.command = form.command.trim();
|
||||
const args = form.args
|
||||
.split("\n")
|
||||
.map((s) => s.trim())
|
||||
.filter(Boolean);
|
||||
if (args.length) body.args = args;
|
||||
const env = parseKeyValueLines(form.env);
|
||||
if (Object.keys(env).length) body.env = env;
|
||||
if (form.cwd.trim()) body.cwd = form.cwd.trim();
|
||||
} else if (form.transport === "http" || form.transport === "sse") {
|
||||
body.url = form.url.trim();
|
||||
const headers = parseKeyValueLines(form.headers);
|
||||
if (Object.keys(headers).length) body.headers = headers;
|
||||
} else if (form.transport === "unix") {
|
||||
body.socket_path = form.socketPath.trim();
|
||||
}
|
||||
return body;
|
||||
}
|
||||
|
||||
export default function McpServersPanel() {
|
||||
const [servers, setServers] = useState<McpServer[] | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const [adding, setAdding] = useState(false);
|
||||
const [form, setForm] = useState<AddFormState>(EMPTY_FORM);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
const [submitError, setSubmitError] = useState<string | null>(null);
|
||||
|
||||
const [busyByName, setBusyByName] = useState<Record<string, boolean>>({});
|
||||
|
||||
const refresh = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const { servers } = await mcpApi.listServers();
|
||||
setServers(servers);
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Failed to load MCP servers");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refresh();
|
||||
}, []);
|
||||
|
||||
const setBusy = (name: string, v: boolean) =>
|
||||
setBusyByName((p) => ({ ...p, [name]: v }));
|
||||
|
||||
const handleToggle = async (server: McpServer) => {
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.setEnabled(server.name, !server.enabled);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Toggle failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRemove = async (server: McpServer) => {
|
||||
if (!confirm(`Remove MCP server "${server.name}"?`)) return;
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.removeServer(server.name);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
const body = (e as { body?: { error?: string } }).body;
|
||||
setError(body?.error || (e as Error)?.message || "Remove failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleHealth = async (server: McpServer) => {
|
||||
setBusy(server.name, true);
|
||||
try {
|
||||
await mcpApi.checkHealth(server.name);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
setError((e as Error)?.message || "Health check failed");
|
||||
} finally {
|
||||
setBusy(server.name, false);
|
||||
}
|
||||
};
|
||||
|
||||
const canSubmit = (() => {
|
||||
if (!form.name.trim()) return false;
|
||||
if (form.transport === "stdio") return !!form.command.trim();
|
||||
if (form.transport === "http" || form.transport === "sse")
|
||||
return !!form.url.trim();
|
||||
if (form.transport === "unix") return !!form.socketPath.trim();
|
||||
return false;
|
||||
})();
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!canSubmit) return;
|
||||
setSubmitting(true);
|
||||
setSubmitError(null);
|
||||
try {
|
||||
const body = buildAddBody(form);
|
||||
const { server } = await mcpApi.addServer(body);
|
||||
// Best-effort: auto-run health check so the UI shows tool count.
|
||||
try {
|
||||
await mcpApi.checkHealth(server.name);
|
||||
} catch {
|
||||
/* health check is informational; don't block the add flow */
|
||||
}
|
||||
setAdding(false);
|
||||
setForm(EMPTY_FORM);
|
||||
await refresh();
|
||||
} catch (e: unknown) {
|
||||
const body = (e as { body?: { error?: string; fix?: string } }).body;
|
||||
setSubmitError(
|
||||
[body?.error, body?.fix].filter(Boolean).join(" — ") ||
|
||||
(e as Error)?.message ||
|
||||
"Add failed",
|
||||
);
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Group by origin. "local" = user-registered via the UI or CLI. Everything
|
||||
// else (built-in package entries, registry-installed entries) sits under
|
||||
// "Built-in" since the user can't remove them from the UI.
|
||||
const builtIns = (servers || []).filter((s) => s.source !== "local");
|
||||
const custom = (servers || []).filter((s) => s.source === "local");
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-foreground">MCP Servers</h3>
|
||||
<p className="text-sm text-muted-foreground mt-1">
|
||||
Register your own MCP servers so queens can use their tools. New
|
||||
servers take effect in the next queen session you start.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={refresh}
|
||||
disabled={loading}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md border border-border/60 text-xs text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50"
|
||||
title="Refresh"
|
||||
>
|
||||
<RefreshCw className={`w-3 h-3 ${loading ? "animate-spin" : ""}`} />
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setAdding(true);
|
||||
setForm(EMPTY_FORM);
|
||||
setSubmitError(null);
|
||||
}}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-semibold hover:bg-primary/90"
|
||||
>
|
||||
<Plus className="w-3 h-3" />
|
||||
Add MCP Server
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive p-2.5 rounded-md bg-destructive/10 border border-destructive/30">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span className="flex-1">{error}</span>
|
||||
<button
|
||||
onClick={() => setError(null)}
|
||||
className="text-destructive/70 hover:text-destructive"
|
||||
>
|
||||
<X className="w-3 h-3" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{loading && !servers && (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<Loader2 className="w-3 h-3 animate-spin" /> Loading MCP servers…
|
||||
</div>
|
||||
)}
|
||||
|
||||
{servers && (
|
||||
<>
|
||||
{custom.length > 0 && (
|
||||
<Section title="My Custom">
|
||||
{custom.map((s) => (
|
||||
<ServerRow
|
||||
key={s.name}
|
||||
server={s}
|
||||
busy={!!busyByName[s.name]}
|
||||
onToggle={() => handleToggle(s)}
|
||||
onRemove={() => handleRemove(s)}
|
||||
onHealth={() => handleHealth(s)}
|
||||
isLocal
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
)}
|
||||
<Section title="Built-in">
|
||||
{builtIns.length === 0 ? (
|
||||
<p className="text-xs text-muted-foreground px-2 py-2">
|
||||
No built-in servers registered.
|
||||
</p>
|
||||
) : (
|
||||
builtIns.map((s) => (
|
||||
<ServerRow
|
||||
key={s.name}
|
||||
server={s}
|
||||
busy={!!busyByName[s.name]}
|
||||
onToggle={() => handleToggle(s)}
|
||||
onRemove={() => handleRemove(s)}
|
||||
onHealth={() => handleHealth(s)}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</Section>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Add MCP modal */}
|
||||
{adding && (
|
||||
<div className="fixed inset-0 z-[60] flex items-center justify-center">
|
||||
<div
|
||||
className="absolute inset-0 bg-black/50"
|
||||
onClick={() => !submitting && setAdding(false)}
|
||||
/>
|
||||
<div className="relative bg-card border border-border/60 rounded-xl shadow-2xl w-full max-w-lg p-5 space-y-4 max-h-[85vh] overflow-y-auto">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="text-sm font-semibold text-foreground">
|
||||
Add MCP Server
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => !submitting && setAdding(false)}
|
||||
className="p-1 rounded text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<X className="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<FieldRow label="Name *" hint="Unique identifier, e.g. my-search-tool">
|
||||
<input
|
||||
autoFocus
|
||||
value={form.name}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({
|
||||
...f,
|
||||
name: e.target.value.toLowerCase().replace(/[^a-z0-9_-]/g, ""),
|
||||
}))
|
||||
}
|
||||
placeholder="my-search-tool"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
|
||||
<FieldRow label="Transport *">
|
||||
<div className="flex gap-1">
|
||||
{TRANSPORT_OPTIONS.map((t) => (
|
||||
<button
|
||||
key={t}
|
||||
onClick={() => setForm((f) => ({ ...f, transport: t }))}
|
||||
className={`flex-1 px-3 py-1.5 rounded-md text-xs font-medium border ${
|
||||
form.transport === t
|
||||
? "bg-primary/15 text-primary border-primary/40"
|
||||
: "text-muted-foreground hover:text-foreground border-border/60 hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
{t}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</FieldRow>
|
||||
|
||||
{form.transport === "stdio" && (
|
||||
<>
|
||||
<FieldRow
|
||||
label="Command *"
|
||||
hint="Executable that speaks MCP over stdin/stdout"
|
||||
>
|
||||
<input
|
||||
value={form.command}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, command: e.target.value }))
|
||||
}
|
||||
placeholder="uv"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Args (one per line)">
|
||||
<textarea
|
||||
value={form.args}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, args: e.target.value }))
|
||||
}
|
||||
rows={3}
|
||||
placeholder={"run\npython\nmy_server.py\n--stdio"}
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Env (KEY=VALUE, one per line)">
|
||||
<textarea
|
||||
value={form.env}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, env: e.target.value }))
|
||||
}
|
||||
rows={2}
|
||||
placeholder="API_KEY=abc123"
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Working directory">
|
||||
<input
|
||||
value={form.cwd}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, cwd: e.target.value }))
|
||||
}
|
||||
placeholder="/path/to/repo"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
</>
|
||||
)}
|
||||
|
||||
{(form.transport === "http" || form.transport === "sse") && (
|
||||
<>
|
||||
<FieldRow label="URL *">
|
||||
<input
|
||||
value={form.url}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, url: e.target.value }))
|
||||
}
|
||||
placeholder="https://example.com/mcp"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
<FieldRow label="Headers (KEY=VALUE, one per line)">
|
||||
<textarea
|
||||
value={form.headers}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, headers: e.target.value }))
|
||||
}
|
||||
rows={2}
|
||||
placeholder="Authorization=Bearer ..."
|
||||
className={textareaCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
</>
|
||||
)}
|
||||
|
||||
{form.transport === "unix" && (
|
||||
<FieldRow label="Socket path *">
|
||||
<input
|
||||
value={form.socketPath}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, socketPath: e.target.value }))
|
||||
}
|
||||
placeholder="/tmp/mcp.sock"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
)}
|
||||
|
||||
<FieldRow label="Description">
|
||||
<input
|
||||
value={form.description}
|
||||
onChange={(e) =>
|
||||
setForm((f) => ({ ...f, description: e.target.value }))
|
||||
}
|
||||
placeholder="What this server does"
|
||||
className={inputCls}
|
||||
/>
|
||||
</FieldRow>
|
||||
|
||||
{submitError && (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive p-2 rounded-md bg-destructive/10 border border-destructive/30">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{submitError}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-end gap-2 pt-1">
|
||||
<button
|
||||
onClick={() => setAdding(false)}
|
||||
disabled={submitting}
|
||||
className="px-3 py-1.5 rounded-md text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSubmit}
|
||||
disabled={!canSubmit || submitting}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-semibold hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{submitting ? (
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
) : (
|
||||
<Check className="w-3 h-3" />
|
||||
)}
|
||||
{submitting ? "Adding…" : "Add"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const inputCls =
|
||||
"w-full bg-muted/30 border border-border/50 rounded-lg px-3 py-2 text-sm text-foreground focus:outline-none focus:ring-1 focus:ring-primary/40";
|
||||
const textareaCls = `${inputCls} resize-none font-mono text-xs`;
|
||||
|
||||
function FieldRow({
|
||||
label,
|
||||
hint,
|
||||
children,
|
||||
}: {
|
||||
label: string;
|
||||
hint?: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<label className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider mb-1.5 block">
|
||||
{label}
|
||||
</label>
|
||||
{children}
|
||||
{hint && (
|
||||
<p className="text-[11px] text-muted-foreground/70 mt-1">{hint}</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function Section({
|
||||
title,
|
||||
children,
|
||||
}: {
|
||||
title: string;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<p className="text-[11px] font-semibold text-muted-foreground/60 uppercase tracking-wider mb-2">
|
||||
{title}
|
||||
</p>
|
||||
<div className="flex flex-col gap-1">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ServerRow({
|
||||
server,
|
||||
busy,
|
||||
onToggle,
|
||||
onRemove,
|
||||
onHealth,
|
||||
isLocal,
|
||||
}: {
|
||||
server: McpServer;
|
||||
busy: boolean;
|
||||
onToggle: () => void;
|
||||
onRemove: () => void;
|
||||
onHealth: () => void;
|
||||
isLocal?: boolean;
|
||||
}) {
|
||||
// Package-baked servers live in the repo and aren't managed by
|
||||
// MCPRegistry, so toggling / removing / health-checking them would
|
||||
// fail against the backend. Show them as read-only.
|
||||
const isBuiltIn = server.source === "built-in";
|
||||
return (
|
||||
<div className="flex items-center gap-3 py-2.5 px-2 rounded-lg hover:bg-muted/20">
|
||||
<div className="w-9 h-9 rounded-full bg-primary/10 flex items-center justify-center flex-shrink-0">
|
||||
<Server className="w-4 h-4 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<p className="text-sm font-medium text-foreground truncate">
|
||||
{server.name}
|
||||
</p>
|
||||
<span className="text-[10px] uppercase tracking-wider text-muted-foreground/60">
|
||||
{server.transport}
|
||||
</span>
|
||||
{isBuiltIn && (
|
||||
<span className="text-[10px] uppercase tracking-wider text-muted-foreground/80 bg-muted/40 px-1.5 py-0.5 rounded">
|
||||
Built-in
|
||||
</span>
|
||||
)}
|
||||
{server.tool_count !== null && server.tool_count !== undefined && (
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{server.tool_count} tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{!isBuiltIn && healthBadge(server)}
|
||||
{server.description && (
|
||||
<span className="text-xs text-muted-foreground truncate">
|
||||
{isBuiltIn ? server.description : `· ${server.description}`}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{!isBuiltIn && (
|
||||
<>
|
||||
<button
|
||||
onClick={onHealth}
|
||||
disabled={busy}
|
||||
className="p-1.5 rounded-md text-muted-foreground hover:text-foreground hover:bg-muted/40 disabled:opacity-50"
|
||||
title="Health check"
|
||||
>
|
||||
{busy ? (
|
||||
<Loader2 className="w-3.5 h-3.5 animate-spin" />
|
||||
) : (
|
||||
<RefreshCw className="w-3.5 h-3.5" />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={onToggle}
|
||||
disabled={busy}
|
||||
className={`px-3 py-1 rounded-md text-[11px] font-semibold border disabled:opacity-50 ${
|
||||
server.enabled
|
||||
? "text-muted-foreground border-border/60 hover:bg-muted/30"
|
||||
: "bg-primary/15 text-primary border-primary/40 hover:bg-primary/25"
|
||||
}`}
|
||||
>
|
||||
{server.enabled ? "Disable" : "Enable"}
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
{isLocal && !isBuiltIn && (
|
||||
<button
|
||||
onClick={onRemove}
|
||||
disabled={busy}
|
||||
className="p-1.5 rounded-md text-muted-foreground hover:text-red-400 hover:bg-red-500/10 disabled:opacity-50"
|
||||
title="Remove"
|
||||
>
|
||||
<Trash2 className="w-3.5 h-3.5" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import { executionApi } from "@/api/execution";
|
||||
import { compressImage } from "@/lib/image-utils";
|
||||
import type { Colony } from "@/types/colony";
|
||||
import { slugToColonyId } from "@/lib/colony-registry";
|
||||
import QueenToolsSection from "./QueenToolsSection";
|
||||
|
||||
interface QueenProfilePanelProps {
|
||||
queenId: string;
|
||||
@@ -354,6 +355,10 @@ export default function QueenProfilePanel({ queenId, colonies, onClose }: QueenP
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-6">
|
||||
<QueenToolsSection queenId={queenId} />
|
||||
</div>
|
||||
|
||||
{colonies.length > 0 && (
|
||||
<div>
|
||||
<h4 className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider mb-2">Assigned Colonies</h4>
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { useCallback } from "react";
|
||||
import { queensApi } from "@/api/queens";
|
||||
import ToolsEditor from "./ToolsEditor";
|
||||
|
||||
export default function QueenToolsSection({ queenId }: { queenId: string }) {
|
||||
const fetchSnapshot = useCallback(
|
||||
() => queensApi.getTools(queenId),
|
||||
[queenId],
|
||||
);
|
||||
const saveAllowlist = useCallback(
|
||||
(enabled: string[] | null) => queensApi.updateTools(queenId, enabled),
|
||||
[queenId],
|
||||
);
|
||||
const resetToRoleDefault = useCallback(
|
||||
() => queensApi.resetTools(queenId),
|
||||
[queenId],
|
||||
);
|
||||
return (
|
||||
<ToolsEditor
|
||||
subjectKey={`queen:${queenId}`}
|
||||
title="Tools"
|
||||
fetchSnapshot={fetchSnapshot}
|
||||
saveAllowlist={saveAllowlist}
|
||||
resetToRoleDefault={resetToRoleDefault}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -6,11 +6,12 @@ import { useModel, LLM_PROVIDERS } from "@/context/ModelContext";
|
||||
import { credentialsApi } from "@/api/credentials";
|
||||
import { configApi, type ModelOption } from "@/api/config";
|
||||
import { compressImage } from "@/lib/image-utils";
|
||||
import McpServersPanel from "./McpServersPanel";
|
||||
|
||||
interface SettingsModalProps {
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
initialSection?: "profile" | "byok";
|
||||
initialSection?: "profile" | "byok" | "mcp";
|
||||
}
|
||||
|
||||
function ValidationBadge({ state }: { state: "validating" | { valid: boolean | null; message: string } | undefined }) {
|
||||
@@ -37,7 +38,7 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
|
||||
const [displayName, setDisplayName] = useState(userProfile.displayName);
|
||||
const [about, setAbout] = useState(userProfile.about);
|
||||
const [activeSection, setActiveSection] = useState<"profile" | "byok">(initialSection || "profile");
|
||||
const [activeSection, setActiveSection] = useState<"profile" | "byok" | "mcp">(initialSection || "profile");
|
||||
const [editingProvider, setEditingProvider] = useState<string | null>(null);
|
||||
const [keyInput, setKeyInput] = useState("");
|
||||
const [showKey, setShowKey] = useState(false);
|
||||
@@ -187,6 +188,12 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
>
|
||||
BYOK
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setActiveSection("mcp")}
|
||||
className={`text-left text-sm px-3 py-1.5 rounded-md ${activeSection === "mcp" ? "bg-primary/15 text-primary font-medium" : "text-muted-foreground hover:text-foreground hover:bg-muted/30"}`}
|
||||
>
|
||||
MCP Servers
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -267,6 +274,8 @@ export default function SettingsModal({ open, onClose, initialSection }: Setting
|
||||
</>
|
||||
)}
|
||||
|
||||
{activeSection === "mcp" && <McpServersPanel />}
|
||||
|
||||
{activeSection === "byok" && (
|
||||
<>
|
||||
<div>
|
||||
|
||||
@@ -5,13 +5,13 @@ import {
|
||||
ChevronRight,
|
||||
MessageSquarePlus,
|
||||
Network,
|
||||
Sparkles,
|
||||
KeyRound,
|
||||
ChevronDown,
|
||||
Plus,
|
||||
X,
|
||||
Crown,
|
||||
Loader2,
|
||||
Library,
|
||||
} from "lucide-react";
|
||||
import SidebarColonyItem from "./SidebarColonyItem";
|
||||
import SidebarQueenItem from "./SidebarQueenItem";
|
||||
@@ -28,6 +28,7 @@ export default function Sidebar() {
|
||||
);
|
||||
const [coloniesExpanded, setColoniesExpanded] = useState(true);
|
||||
const [queensExpanded, setQueensExpanded] = useState(true);
|
||||
const [libraryExpanded, setLibraryExpanded] = useState(false);
|
||||
|
||||
// Colony creation
|
||||
const [createColonyOpen, setCreateColonyOpen] = useState(false);
|
||||
@@ -165,13 +166,6 @@ export default function Sidebar() {
|
||||
<Network className="w-4 h-4" />
|
||||
<span>Org Chart</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/prompt-library")}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<Sparkles className="w-4 h-4" />
|
||||
<span>Prompt Library</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/credentials")}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
@@ -179,6 +173,40 @@ export default function Sidebar() {
|
||||
<KeyRound className="w-4 h-4" />
|
||||
<span>Credentials</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setLibraryExpanded((v) => !v)}
|
||||
className="flex items-center gap-2.5 px-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<Library className="w-4 h-4" />
|
||||
<span className="flex-1 text-left">Configuration</span>
|
||||
<ChevronDown
|
||||
className={`w-3.5 h-3.5 transition-transform ${
|
||||
libraryExpanded ? "" : "-rotate-90"
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
{libraryExpanded && (
|
||||
<>
|
||||
<button
|
||||
onClick={() => navigate("/prompt-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Prompts</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/skills-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Skills</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => navigate("/tool-library")}
|
||||
className="flex items-center gap-2.5 pl-9 pr-3 py-1.5 rounded-md text-sm text-foreground/70 hover:bg-sidebar-item-hover hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>Tools</span>
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* COLONIES section */}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
import { useState } from "react";
|
||||
import { Check, Circle, Hourglass, Loader2 } from "lucide-react";
|
||||
|
||||
import type { TaskRecord, TaskStatus } from "@/api/tasks";
|
||||
|
||||
interface TaskItemProps {
|
||||
task: TaskRecord;
|
||||
unresolvedBlockers: number[];
|
||||
onJumpToBlocker?: (id: number) => void;
|
||||
}
|
||||
|
||||
const STATUS_ICON: Record<TaskStatus, JSX.Element> = {
|
||||
in_progress: (
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin text-amber-500" aria-label="in progress" />
|
||||
),
|
||||
pending: <Circle className="h-3.5 w-3.5 text-muted-foreground" aria-label="pending" />,
|
||||
completed: <Check className="h-3.5 w-3.5 text-emerald-600" aria-label="completed" />,
|
||||
};
|
||||
|
||||
function elapsedSince(ts: number): string {
|
||||
const now = Date.now() / 1000;
|
||||
const diff = Math.max(0, now - ts);
|
||||
if (diff < 60) return `${Math.floor(diff)}s`;
|
||||
if (diff < 3600) return `${Math.floor(diff / 60)}m ${Math.floor(diff % 60)}s`;
|
||||
return `${Math.floor(diff / 3600)}h ${Math.floor((diff % 3600) / 60)}m`;
|
||||
}
|
||||
|
||||
export default function TaskItem({ task, unresolvedBlockers, onJumpToBlocker }: TaskItemProps) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const isBlocked = task.status === "pending" && unresolvedBlockers.length > 0;
|
||||
const elapsed = task.status === "in_progress" ? elapsedSince(task.updated_at) : null;
|
||||
|
||||
return (
|
||||
<li className="group">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExpanded((v) => !v)}
|
||||
className="w-full text-left flex items-start gap-2 px-2 py-1.5 rounded hover:bg-muted/50 focus:bg-muted/60 focus:outline-none"
|
||||
>
|
||||
<span className="mt-0.5 flex-shrink-0">
|
||||
{isBlocked ? (
|
||||
<Hourglass
|
||||
className="h-3.5 w-3.5 text-muted-foreground/70"
|
||||
aria-label="waiting on dependency"
|
||||
/>
|
||||
) : (
|
||||
STATUS_ICON[task.status]
|
||||
)}
|
||||
</span>
|
||||
<span className="flex-1 min-w-0">
|
||||
<span className="text-sm flex items-baseline gap-1.5">
|
||||
<span className="text-muted-foreground tabular-nums">#{task.id}</span>
|
||||
<span className="truncate">
|
||||
{task.status === "in_progress" && task.active_form
|
||||
? task.active_form
|
||||
: task.subject}
|
||||
</span>
|
||||
</span>
|
||||
<span className="flex items-center gap-2 text-xs text-muted-foreground mt-0.5">
|
||||
{task.owner ? (
|
||||
<span className="rounded bg-muted px-1.5 py-0.5">{task.owner.slice(0, 12)}</span>
|
||||
) : null}
|
||||
{elapsed ? <span>{elapsed}</span> : null}
|
||||
{unresolvedBlockers.length > 0 ? (
|
||||
<span>
|
||||
blocked by{" "}
|
||||
{unresolvedBlockers.map((b, idx) => (
|
||||
<span key={b}>
|
||||
<button
|
||||
type="button"
|
||||
className="text-foreground/70 hover:underline"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onJumpToBlocker?.(b);
|
||||
}}
|
||||
>
|
||||
#{b}
|
||||
</button>
|
||||
{idx < unresolvedBlockers.length - 1 ? ", " : ""}
|
||||
</span>
|
||||
))}
|
||||
</span>
|
||||
) : null}
|
||||
</span>
|
||||
</span>
|
||||
</button>
|
||||
{expanded ? (
|
||||
<div className="ml-7 mb-2 text-xs text-muted-foreground space-y-1">
|
||||
{task.description ? <p className="whitespace-pre-wrap">{task.description}</p> : null}
|
||||
{task.metadata && Object.keys(task.metadata).length > 0 ? (
|
||||
<pre className="text-[10px] bg-muted/40 rounded p-2 overflow-x-auto">
|
||||
{JSON.stringify(task.metadata, null, 2)}
|
||||
</pre>
|
||||
) : null}
|
||||
<p className="text-[10px]">
|
||||
updated {new Date(task.updated_at * 1000).toLocaleString()}
|
||||
</p>
|
||||
</div>
|
||||
) : null}
|
||||
</li>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
/**
|
||||
* Task list panel — renders one task list (queen-DM session, colony
|
||||
* template, or worker session). Variants:
|
||||
*
|
||||
* variant="rail" -> right-rail panel with header & close button
|
||||
* variant="embedded" -> inline (e.g., inside WorkerDetail)
|
||||
*/
|
||||
|
||||
import { useRef, useState } from "react";
|
||||
import { ChevronDown, ChevronRight, X } from "lucide-react";
|
||||
|
||||
import {
|
||||
TaskListProvider,
|
||||
useTaskList,
|
||||
bucketTasks,
|
||||
unresolvedBlockers,
|
||||
} from "@/context/TaskListContext";
|
||||
import TaskItem from "@/components/TaskItem";
|
||||
import type { TaskRecord } from "@/api/tasks";
|
||||
|
||||
interface TaskListPanelProps {
|
||||
taskListId: string;
|
||||
sessionId?: string;
|
||||
/** Override the default header label. */
|
||||
title?: string;
|
||||
variant?: "rail" | "embedded";
|
||||
onClose?: () => void;
|
||||
}
|
||||
|
||||
export default function TaskListPanel(props: TaskListPanelProps) {
|
||||
return (
|
||||
<TaskListProvider taskListId={props.taskListId} sessionId={props.sessionId}>
|
||||
<TaskListPanelInner {...props} />
|
||||
</TaskListProvider>
|
||||
);
|
||||
}
|
||||
|
||||
function TaskListPanelInner({ title, variant = "rail", onClose }: TaskListPanelProps) {
|
||||
const { tasks, loading, error, role, exists } = useTaskList();
|
||||
const buckets = bucketTasks(tasks);
|
||||
|
||||
// Don't render anything when the list doesn't exist yet AND we're in
|
||||
// the rail variant (queen-DM session that hasn't created any task).
|
||||
// The embedded variant always shows so the section in WorkerDetail/
|
||||
// colony overview keeps a stable layout.
|
||||
if (!loading && !exists && variant === "rail") return null;
|
||||
|
||||
const [activeOpen, setActiveOpen] = useState(true);
|
||||
const [pendingOpen, setPendingOpen] = useState(true);
|
||||
const [completedOpen, setCompletedOpen] = useState(false);
|
||||
|
||||
const itemRefs = useRef(new Map<number, HTMLLIElement>());
|
||||
const handleJumpToBlocker = (id: number) => {
|
||||
const node = itemRefs.current.get(id);
|
||||
if (!node) return;
|
||||
node.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||
node.classList.add("ring-2", "ring-primary/40");
|
||||
setTimeout(() => node.classList.remove("ring-2", "ring-primary/40"), 1500);
|
||||
};
|
||||
|
||||
const headerLabel =
|
||||
title ??
|
||||
(role === "template"
|
||||
? "Colony plan"
|
||||
: role === "session"
|
||||
? "Tasks"
|
||||
: "Tasks");
|
||||
const inProgressCount = buckets.active.length;
|
||||
const totalVisible = buckets.visible.length;
|
||||
|
||||
return (
|
||||
<aside
|
||||
className={
|
||||
variant === "rail"
|
||||
? "w-[320px] flex-shrink-0 border-l border-border bg-background flex flex-col h-full overflow-hidden"
|
||||
: "w-full border border-border rounded-md bg-background flex flex-col"
|
||||
}
|
||||
>
|
||||
<div className="flex items-center justify-between px-3 py-2 border-b border-border">
|
||||
<h2 className="text-sm font-semibold flex items-center gap-2">
|
||||
<span>{headerLabel}</span>
|
||||
<span className="text-xs text-muted-foreground tabular-nums">
|
||||
{inProgressCount}/{totalVisible}
|
||||
</span>
|
||||
</h2>
|
||||
{onClose ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="text-muted-foreground hover:text-foreground"
|
||||
aria-label="Close"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
) : null}
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto p-2">
|
||||
{loading ? (
|
||||
<p className="text-xs text-muted-foreground p-2">Loading…</p>
|
||||
) : error ? (
|
||||
<p className="text-xs text-destructive p-2">Error: {error}</p>
|
||||
) : totalVisible === 0 ? (
|
||||
<p className="text-xs text-muted-foreground p-2">
|
||||
{role === "template"
|
||||
? "No template entries yet. The queen will populate this when planning a fan-out."
|
||||
: "No tasks yet. The agent will create them as it plans."}
|
||||
</p>
|
||||
) : (
|
||||
<>
|
||||
{/* Completed sits above Active so finished tasks stay visually
|
||||
* "above" the work that came after them — preserves the order
|
||||
* the user originally saw before the status flipped. */}
|
||||
<Section
|
||||
label="Completed"
|
||||
count={buckets.completed.length}
|
||||
open={completedOpen}
|
||||
onToggle={() => setCompletedOpen((v) => !v)}
|
||||
>
|
||||
{buckets.completed.map((t) => (
|
||||
<RefItem
|
||||
key={t.id}
|
||||
task={t}
|
||||
itemRefs={itemRefs}
|
||||
unresolved={[]}
|
||||
onJumpToBlocker={handleJumpToBlocker}
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
<Section
|
||||
label="Active"
|
||||
count={buckets.active.length}
|
||||
open={activeOpen}
|
||||
onToggle={() => setActiveOpen((v) => !v)}
|
||||
>
|
||||
{buckets.active.map((t) => (
|
||||
<RefItem
|
||||
key={t.id}
|
||||
task={t}
|
||||
itemRefs={itemRefs}
|
||||
unresolved={unresolvedBlockers(t, buckets.completedIds)}
|
||||
onJumpToBlocker={handleJumpToBlocker}
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
<Section
|
||||
label="Pending"
|
||||
count={buckets.pending.length}
|
||||
open={pendingOpen}
|
||||
onToggle={() => setPendingOpen((v) => !v)}
|
||||
>
|
||||
{buckets.pending.map((t) => (
|
||||
<RefItem
|
||||
key={t.id}
|
||||
task={t}
|
||||
itemRefs={itemRefs}
|
||||
unresolved={unresolvedBlockers(t, buckets.completedIds)}
|
||||
onJumpToBlocker={handleJumpToBlocker}
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
}
|
||||
|
||||
function Section({
|
||||
label,
|
||||
count,
|
||||
open,
|
||||
onToggle,
|
||||
children,
|
||||
}: {
|
||||
label: string;
|
||||
count: number;
|
||||
open: boolean;
|
||||
onToggle: () => void;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
if (count === 0) return null;
|
||||
return (
|
||||
<div className="mb-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onToggle}
|
||||
className="flex items-center gap-1 text-xs font-medium text-muted-foreground px-2 py-1 hover:text-foreground"
|
||||
>
|
||||
{open ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
|
||||
<span>{label}</span>
|
||||
<span className="tabular-nums">({count})</span>
|
||||
</button>
|
||||
{open ? <ul className="space-y-0.5">{children}</ul> : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function RefItem({
|
||||
task,
|
||||
itemRefs,
|
||||
unresolved,
|
||||
onJumpToBlocker,
|
||||
}: {
|
||||
task: TaskRecord;
|
||||
itemRefs: React.MutableRefObject<Map<number, HTMLLIElement>>;
|
||||
unresolved: number[];
|
||||
onJumpToBlocker: (id: number) => void;
|
||||
}) {
|
||||
return (
|
||||
<li
|
||||
ref={(el) => {
|
||||
if (el) itemRefs.current.set(task.id, el);
|
||||
else itemRefs.current.delete(task.id);
|
||||
}}
|
||||
className="rounded transition-shadow"
|
||||
>
|
||||
<TaskItem
|
||||
task={task}
|
||||
unresolvedBlockers={unresolved}
|
||||
onJumpToBlocker={onJumpToBlocker}
|
||||
/>
|
||||
</li>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stacked variant: two TaskListPanels (colony template + queen session).
|
||||
// Used in the colony chat right rail.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface TaskListPanelStackedProps {
|
||||
templateTaskListId: string;
|
||||
queenSessionTaskListId: string | null;
|
||||
sessionId: string;
|
||||
onClose?: () => void;
|
||||
}
|
||||
|
||||
export function TaskListPanelStacked(props: TaskListPanelStackedProps) {
|
||||
return (
|
||||
<aside className="w-[320px] flex-shrink-0 border-l border-border bg-background flex flex-col h-full overflow-hidden">
|
||||
<div className="flex items-center justify-between px-3 py-2 border-b border-border">
|
||||
<h2 className="text-sm font-semibold">Tasks</h2>
|
||||
{props.onClose ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={props.onClose}
|
||||
className="text-muted-foreground hover:text-foreground"
|
||||
aria-label="Close"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="flex-1 min-h-0 flex flex-col overflow-hidden">
|
||||
<div className="flex-1 min-h-0 overflow-hidden border-b border-border">
|
||||
<TaskListPanel
|
||||
taskListId={props.templateTaskListId}
|
||||
sessionId={props.sessionId}
|
||||
title="Colony plan"
|
||||
variant="embedded"
|
||||
/>
|
||||
</div>
|
||||
{props.queenSessionTaskListId ? (
|
||||
<div className="flex-1 min-h-0 overflow-hidden">
|
||||
<TaskListPanel
|
||||
taskListId={props.queenSessionTaskListId}
|
||||
sessionId={props.sessionId}
|
||||
title="Queen's notes"
|
||||
variant="embedded"
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,508 @@
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
ChevronDown,
|
||||
ChevronRight,
|
||||
Check,
|
||||
Loader2,
|
||||
Lock,
|
||||
Wrench,
|
||||
AlertCircle,
|
||||
} from "lucide-react";
|
||||
import type { ToolMeta, McpServerTools } from "@/api/queens";
|
||||
|
||||
/** Shape every Tools section (Queen / Colony) shares. */
|
||||
export interface ToolsSnapshot {
|
||||
enabled_mcp_tools: string[] | null;
|
||||
stale: boolean;
|
||||
lifecycle: ToolMeta[];
|
||||
synthetic: ToolMeta[];
|
||||
mcp_servers: McpServerTools[];
|
||||
/** Optional: when true, the allowlist came from the role-based
|
||||
* default (no explicit save). Only queens surface this today. */
|
||||
is_role_default?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolsEditorProps {
|
||||
/** Stable identifier — refetches when it changes. */
|
||||
subjectKey: string;
|
||||
/** Title shown above the controls. */
|
||||
title?: string;
|
||||
/** One-line caveat rendered under the header (e.g. "Changes apply …"). */
|
||||
caveat?: string;
|
||||
/** Load the current snapshot. */
|
||||
fetchSnapshot: () => Promise<ToolsSnapshot>;
|
||||
/** Persist an allowlist. ``null`` is an explicit "allow all" save. */
|
||||
saveAllowlist: (
|
||||
enabled: string[] | null,
|
||||
) => Promise<{ enabled_mcp_tools: string[] | null }>;
|
||||
/** Optional: drop any saved allowlist so the subject falls back to
|
||||
* its role-based default. Shows a "Reset to role default" button
|
||||
* when provided. */
|
||||
resetToRoleDefault?: () => Promise<{ enabled_mcp_tools: string[] | null }>;
|
||||
}
|
||||
|
||||
type TriState = "checked" | "unchecked" | "indeterminate";
|
||||
|
||||
function triStateForServer(
|
||||
toolNames: string[],
|
||||
allowed: Set<string> | null,
|
||||
): TriState {
|
||||
if (allowed === null) return "checked";
|
||||
if (toolNames.length === 0) return "unchecked";
|
||||
const enabledCount = toolNames.reduce(
|
||||
(n, name) => n + (allowed.has(name) ? 1 : 0),
|
||||
0,
|
||||
);
|
||||
if (enabledCount === 0) return "unchecked";
|
||||
if (enabledCount === toolNames.length) return "checked";
|
||||
return "indeterminate";
|
||||
}
|
||||
|
||||
function TriStateCheckbox({
|
||||
state,
|
||||
onChange,
|
||||
disabled,
|
||||
}: {
|
||||
state: TriState;
|
||||
onChange: (next: boolean) => void;
|
||||
disabled?: boolean;
|
||||
}) {
|
||||
const ref = useRef<HTMLInputElement>(null);
|
||||
useEffect(() => {
|
||||
if (ref.current) ref.current.indeterminate = state === "indeterminate";
|
||||
}, [state]);
|
||||
return (
|
||||
<input
|
||||
ref={ref}
|
||||
type="checkbox"
|
||||
checked={state === "checked"}
|
||||
disabled={disabled}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="h-3.5 w-3.5 rounded border-border/70 text-primary focus:ring-primary/40"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolRow({
|
||||
name,
|
||||
description,
|
||||
enabled,
|
||||
editable,
|
||||
onToggle,
|
||||
}: {
|
||||
name: string;
|
||||
description: string;
|
||||
enabled: boolean;
|
||||
editable: boolean;
|
||||
onToggle?: (next: boolean) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 py-1.5 px-2 rounded hover:bg-muted/30">
|
||||
{editable ? (
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
onChange={(e) => onToggle?.(e.target.checked)}
|
||||
className="mt-0.5 h-3.5 w-3.5 rounded border-border/70 text-primary focus:ring-primary/40"
|
||||
/>
|
||||
) : (
|
||||
<Lock className="mt-0.5 h-3 w-3 text-muted-foreground/60 flex-shrink-0" />
|
||||
)}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="text-xs font-medium text-foreground font-mono">
|
||||
{name}
|
||||
</div>
|
||||
{description && (
|
||||
<div className="text-[11px] text-muted-foreground leading-relaxed line-clamp-2">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function CollapsibleGroup({
|
||||
title,
|
||||
count,
|
||||
badge,
|
||||
expanded,
|
||||
onToggle,
|
||||
leading,
|
||||
children,
|
||||
}: {
|
||||
title: string;
|
||||
count: number;
|
||||
badge?: string;
|
||||
expanded: boolean;
|
||||
onToggle: () => void;
|
||||
leading?: React.ReactNode;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div className="mb-2 rounded-lg border border-border/40 bg-muted/10 overflow-hidden">
|
||||
<button
|
||||
onClick={onToggle}
|
||||
className="w-full flex items-center gap-2 px-2.5 py-1.5 text-left hover:bg-muted/30"
|
||||
>
|
||||
{expanded ? (
|
||||
<ChevronDown className="w-3.5 h-3.5 text-muted-foreground" />
|
||||
) : (
|
||||
<ChevronRight className="w-3.5 h-3.5 text-muted-foreground" />
|
||||
)}
|
||||
{leading}
|
||||
<span className="text-xs font-medium text-foreground flex-1 truncate">
|
||||
{title}
|
||||
</span>
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{badge ?? count}
|
||||
</span>
|
||||
</button>
|
||||
{expanded && (
|
||||
<div className="border-t border-border/30 px-1 py-1">{children}</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function ToolsEditor({
|
||||
subjectKey,
|
||||
title = "Tools",
|
||||
caveat,
|
||||
fetchSnapshot,
|
||||
saveAllowlist,
|
||||
resetToRoleDefault,
|
||||
}: ToolsEditorProps) {
|
||||
const [data, setData] = useState<ToolsSnapshot | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const [draftAllowed, setDraftAllowed] = useState<Set<string> | null>(null);
|
||||
const baselineRef = useRef<Set<string> | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const [savedRecently, setSavedRecently] = useState(false);
|
||||
|
||||
const [expanded, setExpanded] = useState<Record<string, boolean>>({});
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
fetchSnapshot()
|
||||
.then((d) => {
|
||||
if (cancelled) return;
|
||||
setData(d);
|
||||
const baseline =
|
||||
d.enabled_mcp_tools === null
|
||||
? null
|
||||
: new Set<string>(d.enabled_mcp_tools);
|
||||
baselineRef.current = baseline === null ? null : new Set(baseline);
|
||||
setDraftAllowed(baseline);
|
||||
})
|
||||
.catch((e) => {
|
||||
if (cancelled) return;
|
||||
setError((e as Error)?.message || "Failed to load tools");
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setLoading(false);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [subjectKey, fetchSnapshot]);
|
||||
|
||||
const allMcpNames = useMemo(() => {
|
||||
const s = new Set<string>();
|
||||
data?.mcp_servers.forEach((srv) => srv.tools.forEach((t) => s.add(t.name)));
|
||||
return s;
|
||||
}, [data]);
|
||||
|
||||
const dirty = useMemo(() => {
|
||||
const a = draftAllowed;
|
||||
const b = baselineRef.current;
|
||||
if (a === null && b === null) return false;
|
||||
if (a === null || b === null) return true;
|
||||
if (a.size !== b.size) return true;
|
||||
for (const n of a) if (!b.has(n)) return true;
|
||||
return false;
|
||||
}, [draftAllowed]);
|
||||
|
||||
const applyResult = (updated: string[] | null, isRoleDefault: boolean) => {
|
||||
baselineRef.current = updated === null ? null : new Set(updated);
|
||||
setDraftAllowed(updated === null ? null : new Set(updated));
|
||||
if (data) {
|
||||
const u = updated === null ? null : new Set(updated);
|
||||
setData({
|
||||
...data,
|
||||
enabled_mcp_tools: updated,
|
||||
is_role_default: isRoleDefault,
|
||||
mcp_servers: data.mcp_servers.map((srv) => ({
|
||||
...srv,
|
||||
tools: srv.tools.map((t) => ({
|
||||
...t,
|
||||
enabled: u === null ? true : u.has(t.name),
|
||||
})),
|
||||
})),
|
||||
});
|
||||
}
|
||||
setSavedRecently(true);
|
||||
setTimeout(() => setSavedRecently(false), 2500);
|
||||
};
|
||||
|
||||
const toggleOne = (name: string, next: boolean) => {
|
||||
setDraftAllowed((prev) => {
|
||||
const base =
|
||||
prev === null ? new Set<string>(allMcpNames) : new Set<string>(prev);
|
||||
if (next) base.add(name);
|
||||
else base.delete(name);
|
||||
return base;
|
||||
});
|
||||
};
|
||||
|
||||
const toggleServer = (serverNames: string[], next: boolean) => {
|
||||
setDraftAllowed((prev) => {
|
||||
const base =
|
||||
prev === null ? new Set<string>(allMcpNames) : new Set<string>(prev);
|
||||
if (next) serverNames.forEach((n) => base.add(n));
|
||||
else serverNames.forEach((n) => base.delete(n));
|
||||
return base;
|
||||
});
|
||||
};
|
||||
|
||||
const handleAllowAll = () => setDraftAllowed(null);
|
||||
|
||||
const handleCancel = () => {
|
||||
const baseline = baselineRef.current;
|
||||
setDraftAllowed(baseline === null ? null : new Set(baseline));
|
||||
setSaveError(null);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
// Only send tool names the server knows about (MCP tools).
|
||||
// The draft may contain lifecycle/synthetic names from the
|
||||
// baseline — strip those to avoid "Unknown MCP tool name" errors.
|
||||
const payload =
|
||||
draftAllowed === null
|
||||
? null
|
||||
: Array.from(draftAllowed)
|
||||
.filter((name) => allMcpNames.has(name))
|
||||
.sort();
|
||||
const result = await saveAllowlist(payload);
|
||||
applyResult(result.enabled_mcp_tools, false);
|
||||
} catch (e: unknown) {
|
||||
const err = e as { body?: { error?: string; unknown?: string[] } };
|
||||
const extra = err.body?.unknown
|
||||
? ` (${err.body.unknown.join(", ")})`
|
||||
: "";
|
||||
setSaveError((err.body?.error || "Save failed") + extra);
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetToRoleDefault = async () => {
|
||||
if (!resetToRoleDefault) return;
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
const result = await resetToRoleDefault();
|
||||
applyResult(result.enabled_mcp_tools, true);
|
||||
} catch (e: unknown) {
|
||||
const err = e as { body?: { error?: string } };
|
||||
setSaveError(err.body?.error || "Reset failed");
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground py-3">
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
Loading tools…
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error || !data) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive py-3">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{error || "Could not load tools"}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const draftEnabledCount =
|
||||
draftAllowed === null ? allMcpNames.size : draftAllowed.size;
|
||||
const totalMcpCount = allMcpNames.size;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-1.5">
|
||||
<h4 className="text-[11px] font-semibold text-muted-foreground uppercase tracking-wider flex items-center gap-1.5">
|
||||
<Wrench className="w-3 h-3" /> {title}
|
||||
</h4>
|
||||
<span className="text-[11px] text-muted-foreground">
|
||||
{draftEnabledCount}/{totalMcpCount} MCP enabled
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{caveat && (
|
||||
<div className="flex items-start gap-1.5 text-[11px] text-muted-foreground mb-2 px-2 py-1.5 rounded bg-muted/20 border border-border/40">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>{caveat}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{data.stale && (
|
||||
<div className="flex items-start gap-1.5 text-[11px] text-muted-foreground mb-3 px-2 py-1.5 rounded bg-muted/30">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>
|
||||
Catalog is unavailable. Start a session once to populate the tool list.
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(data.lifecycle.length > 0 || data.synthetic.length > 0) && (
|
||||
<CollapsibleGroup
|
||||
title="System tools (always enabled)"
|
||||
count={data.lifecycle.length + data.synthetic.length}
|
||||
expanded={!!expanded["__system"]}
|
||||
onToggle={() =>
|
||||
setExpanded((p) => ({ ...p, __system: !p["__system"] }))
|
||||
}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
{data.synthetic.map((t) => (
|
||||
<ToolRow
|
||||
key={`syn-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={true}
|
||||
editable={false}
|
||||
/>
|
||||
))}
|
||||
{data.lifecycle.map((t) => (
|
||||
<ToolRow
|
||||
key={`lc-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={true}
|
||||
editable={false}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</CollapsibleGroup>
|
||||
)}
|
||||
|
||||
{data.mcp_servers.map((srv) => {
|
||||
const toolNames = srv.tools.map((t) => t.name);
|
||||
const state = triStateForServer(toolNames, draftAllowed);
|
||||
const enabledInServer =
|
||||
draftAllowed === null
|
||||
? toolNames.length
|
||||
: toolNames.reduce(
|
||||
(n, name) => n + (draftAllowed.has(name) ? 1 : 0),
|
||||
0,
|
||||
);
|
||||
return (
|
||||
<CollapsibleGroup
|
||||
key={srv.name}
|
||||
title={srv.name === "(unknown)" ? "MCP Tools" : srv.name}
|
||||
count={srv.tools.length}
|
||||
badge={`${enabledInServer}/${srv.tools.length}`}
|
||||
expanded={!!expanded[srv.name]}
|
||||
onToggle={() =>
|
||||
setExpanded((p) => ({ ...p, [srv.name]: !p[srv.name] }))
|
||||
}
|
||||
leading={
|
||||
<TriStateCheckbox
|
||||
state={state}
|
||||
onChange={(next) => toggleServer(toolNames, next)}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
{srv.tools.map((t) => {
|
||||
const enabled =
|
||||
draftAllowed === null ? true : draftAllowed.has(t.name);
|
||||
return (
|
||||
<ToolRow
|
||||
key={`${srv.name}-${t.name}`}
|
||||
name={t.name}
|
||||
description={t.description}
|
||||
enabled={enabled}
|
||||
editable={true}
|
||||
onToggle={(next) => toggleOne(t.name, next)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</CollapsibleGroup>
|
||||
);
|
||||
})}
|
||||
|
||||
<div className="flex items-center gap-2 pt-3 flex-wrap">
|
||||
{/* Primary actions */}
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={!dirty || saving}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 rounded-md bg-primary text-primary-foreground text-xs font-medium hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{saving ? <Loader2 className="w-3 h-3 animate-spin" /> : <Check className="w-3 h-3" />}
|
||||
{saving ? "Saving…" : "Save"}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleCancel}
|
||||
disabled={!dirty || saving}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
|
||||
{/* Status */}
|
||||
{savedRecently && !dirty && (
|
||||
<span className="text-[11px] text-green-500 flex items-center gap-1">
|
||||
<Check className="w-3 h-3" /> Saved
|
||||
</span>
|
||||
)}
|
||||
{dirty && !saving && (
|
||||
<span className="text-[11px] text-amber-500">Unsaved changes</span>
|
||||
)}
|
||||
|
||||
{/* Quick actions */}
|
||||
<div className="ml-auto flex items-center gap-2">
|
||||
<button
|
||||
onClick={handleAllowAll}
|
||||
disabled={saving || draftAllowed === null}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Allow all
|
||||
</button>
|
||||
{resetToRoleDefault && (
|
||||
<button
|
||||
onClick={handleResetToRoleDefault}
|
||||
disabled={saving || !!data.is_role_default}
|
||||
className="px-3 py-1.5 rounded-md border border-border/60 text-xs font-medium text-muted-foreground hover:text-foreground hover:bg-muted/30 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
Reset to defaults
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{saveError && (
|
||||
<div className="flex items-start gap-1.5 mt-2 text-[11px] text-destructive">
|
||||
<AlertCircle className="w-3 h-3 mt-0.5 flex-shrink-0" />
|
||||
<span>{saveError}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* Per-list live task state. Mounts a single list (snapshot + SSE diffs).
|
||||
*
|
||||
* Stack two of these for the colony-overview view (template + queen
|
||||
* session). Mount one for the queen-DM and worker-detail views.
|
||||
*/
|
||||
|
||||
import {
|
||||
createContext,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useReducer,
|
||||
useRef,
|
||||
type ReactNode,
|
||||
} from "react";
|
||||
|
||||
import {
|
||||
tasksApi,
|
||||
type TaskRecord,
|
||||
type TaskListRole,
|
||||
type TaskCreatedEvent,
|
||||
type TaskUpdatedEvent,
|
||||
type TaskDeletedEvent,
|
||||
} from "@/api/tasks";
|
||||
import { useSSE } from "@/hooks/use-sse";
|
||||
import type { AgentEvent } from "@/api/types";
|
||||
|
||||
interface TaskListState {
|
||||
taskListId: string;
|
||||
role: TaskListRole | "unknown";
|
||||
tasks: TaskRecord[];
|
||||
loading: boolean;
|
||||
error: string | null;
|
||||
/** False until the list exists on disk. Sessions that haven't created
|
||||
* any task yet return 404 from the snapshot endpoint; the panel
|
||||
* should hide rather than render an error. Becomes true on first
|
||||
* successful snapshot or on the first task_created event. */
|
||||
exists: boolean;
|
||||
}
|
||||
|
||||
type Action =
|
||||
| { type: "SNAPSHOT"; tasks: TaskRecord[]; role: TaskListRole }
|
||||
| { type: "LOADING" }
|
||||
| { type: "NOT_FOUND" }
|
||||
| { type: "ERROR"; error: string }
|
||||
| { type: "CREATED"; task: TaskRecord }
|
||||
| { type: "UPDATED"; task: TaskRecord }
|
||||
| { type: "DELETED"; taskId: number; cascade: number[] };
|
||||
|
||||
function reducer(state: TaskListState, action: Action): TaskListState {
|
||||
switch (action.type) {
|
||||
case "LOADING":
|
||||
return { ...state, loading: true, error: null };
|
||||
case "NOT_FOUND":
|
||||
return { ...state, loading: false, error: null, exists: false, tasks: [] };
|
||||
case "ERROR":
|
||||
return { ...state, loading: false, error: action.error };
|
||||
case "SNAPSHOT":
|
||||
return {
|
||||
...state,
|
||||
tasks: action.tasks,
|
||||
role: action.role,
|
||||
loading: false,
|
||||
error: null,
|
||||
exists: true,
|
||||
};
|
||||
case "CREATED": {
|
||||
// First task_created event for a previously-empty session marks
|
||||
// the list as existing — the panel will reveal itself live.
|
||||
if (state.tasks.some((t) => t.id === action.task.id)) {
|
||||
return { ...state, exists: true };
|
||||
}
|
||||
const next = [...state.tasks, action.task].sort((a, b) => a.id - b.id);
|
||||
return { ...state, tasks: next, exists: true };
|
||||
}
|
||||
case "UPDATED": {
|
||||
const next = state.tasks.map((t) => (t.id === action.task.id ? action.task : t));
|
||||
return { ...state, tasks: next, exists: true };
|
||||
}
|
||||
case "DELETED": {
|
||||
const surviving = state.tasks
|
||||
.filter((t) => t.id !== action.taskId)
|
||||
.map((t) => {
|
||||
if (action.cascade.includes(t.id)) {
|
||||
return {
|
||||
...t,
|
||||
blocks: t.blocks.filter((b) => b !== action.taskId),
|
||||
blocked_by: t.blocked_by.filter((b) => b !== action.taskId),
|
||||
};
|
||||
}
|
||||
return t;
|
||||
});
|
||||
return { ...state, tasks: surviving };
|
||||
}
|
||||
default:
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
const initial: TaskListState = {
|
||||
taskListId: "",
|
||||
role: "unknown",
|
||||
tasks: [],
|
||||
loading: false,
|
||||
error: null,
|
||||
exists: false,
|
||||
};
|
||||
|
||||
const TaskListContext = createContext<TaskListState | undefined>(undefined);
|
||||
|
||||
interface TaskListProviderProps {
|
||||
taskListId: string;
|
||||
// SSE source — the queen session id is a reasonable default; events for
|
||||
// a list are published on the colony's bus. If `sessionId` is missing,
|
||||
// the panel renders the snapshot but doesn't subscribe to live diffs.
|
||||
sessionId?: string;
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
const TASK_EVENT_TYPES = [
|
||||
"task_created",
|
||||
"task_updated",
|
||||
"task_deleted",
|
||||
"task_list_reset",
|
||||
] as const;
|
||||
|
||||
export function TaskListProvider({ taskListId, sessionId, children }: TaskListProviderProps) {
|
||||
const [state, dispatch] = useReducer(reducer, { ...initial, taskListId });
|
||||
const taskListIdRef = useRef(taskListId);
|
||||
taskListIdRef.current = taskListId;
|
||||
|
||||
// Snapshot fetch — re-run when taskListId changes.
|
||||
useEffect(() => {
|
||||
if (!taskListId) return;
|
||||
let cancelled = false;
|
||||
dispatch({ type: "LOADING" });
|
||||
tasksApi
|
||||
.getList(taskListId)
|
||||
.then((snap) => {
|
||||
if (cancelled) return;
|
||||
if (snap === null) {
|
||||
// Not yet on disk — the panel hides until the first task_created
|
||||
// event arrives via SSE (see CREATED case in the reducer).
|
||||
dispatch({ type: "NOT_FOUND" });
|
||||
} else {
|
||||
dispatch({ type: "SNAPSHOT", tasks: snap.tasks, role: snap.role });
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
if (cancelled) return;
|
||||
dispatch({ type: "ERROR", error: String(err?.message ?? err) });
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [taskListId]);
|
||||
|
||||
// Subscribe to SSE diffs scoped to this list_id.
|
||||
useSSE({
|
||||
sessionId: sessionId ?? "",
|
||||
eventTypes: TASK_EVENT_TYPES as unknown as AgentEvent["type"][],
|
||||
enabled: Boolean(sessionId),
|
||||
onEvent: (ev) => {
|
||||
const data = ev.data ?? {};
|
||||
if (data.task_list_id !== taskListIdRef.current) return;
|
||||
switch (ev.type) {
|
||||
case "task_created":
|
||||
dispatch({ type: "CREATED", task: (data as unknown as TaskCreatedEvent).task });
|
||||
return;
|
||||
case "task_updated":
|
||||
dispatch({ type: "UPDATED", task: (data as unknown as TaskUpdatedEvent).after });
|
||||
return;
|
||||
case "task_deleted": {
|
||||
const d = data as unknown as TaskDeletedEvent;
|
||||
dispatch({ type: "DELETED", taskId: d.task_id, cascade: d.cascade ?? [] });
|
||||
return;
|
||||
}
|
||||
case "task_list_reset":
|
||||
dispatch({ type: "SNAPSHOT", tasks: [], role: state.role === "unknown" ? "session" : state.role });
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return <TaskListContext.Provider value={state}>{children}</TaskListContext.Provider>;
|
||||
}
|
||||
|
||||
export function useTaskList(): TaskListState {
|
||||
const ctx = useContext(TaskListContext);
|
||||
if (!ctx) throw new Error("useTaskList must be used inside <TaskListProvider>");
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Helpers for components that want pre-bucketed views.
|
||||
export function bucketTasks(tasks: TaskRecord[]) {
|
||||
const completedIds = new Set(tasks.filter((t) => t.status === "completed").map((t) => t.id));
|
||||
const visible = tasks.filter((t) => !(t.metadata as { _internal?: boolean })._internal);
|
||||
const active = visible.filter((t) => t.status === "in_progress");
|
||||
const pending = visible.filter((t) => t.status === "pending");
|
||||
const completed = visible.filter((t) => t.status === "completed");
|
||||
return { active, pending, completed, completedIds, visible };
|
||||
}
|
||||
|
||||
export function unresolvedBlockers(task: TaskRecord, completedIds: Set<number>): number[] {
|
||||
return task.blocked_by.filter((b) => !completedIds.has(b));
|
||||
}
|
||||
|
||||
export const TASK_LIST_PANEL_LOCALSTORAGE_KEY = (taskListId: string) =>
|
||||
`taskListPanel.${taskListId}`;
|
||||
|
||||
export const useMemoizedBuckets = (tasks: TaskRecord[]) =>
|
||||
useMemo(() => bucketTasks(tasks), [tasks]);
|
||||
@@ -4,6 +4,8 @@ import Sidebar from "@/components/Sidebar";
|
||||
import AppHeader from "@/components/AppHeader";
|
||||
import QueenProfilePanel from "@/components/QueenProfilePanel";
|
||||
import ColonyWorkersPanel from "@/components/ColonyWorkersPanel";
|
||||
import TaskListPanel, { TaskListPanelStacked } from "@/components/TaskListPanel";
|
||||
import { sessionTaskListId, colonyTaskListId } from "@/api/tasks";
|
||||
import { ColonyProvider, useColony } from "@/context/ColonyContext";
|
||||
import { HeaderActionsProvider } from "@/context/HeaderActionsContext";
|
||||
import { QueenProfileProvider } from "@/context/QueenProfileContext";
|
||||
@@ -64,7 +66,44 @@ function LayoutShell({
|
||||
}) {
|
||||
const { sessionId, colonyName, dismissed, toggleColonyWorkers } =
|
||||
useColonyWorkers();
|
||||
const showWorkersPanel = Boolean(sessionId && !dismissed);
|
||||
// Workers panel is colony-only — queen-DM may publish a sessionId for
|
||||
// the tasks panel below, but we don't want the workers panel showing
|
||||
// up there (no workers exist).
|
||||
const showWorkersPanel = Boolean(sessionId && colonyName && !dismissed);
|
||||
const location = useLocation();
|
||||
const [taskPanelDismissed, setTaskPanelDismissed] = useState(false);
|
||||
|
||||
// Determine which task panel to show based on the current route.
|
||||
// queen-DM (/queen/...) -> single TaskListPanel for queen session
|
||||
// colony chat (/colony/{name}) -> stacked (template + queen session)
|
||||
// anywhere else -> hidden
|
||||
const isColony = location.pathname.startsWith("/colony/");
|
||||
const isQueenDm = location.pathname.startsWith("/queen/");
|
||||
const showTasksPanel = !taskPanelDismissed && Boolean(sessionId) && (isQueenDm || isColony);
|
||||
|
||||
let tasksPanel: ReactNode = null;
|
||||
if (showTasksPanel && sessionId) {
|
||||
if (isColony) {
|
||||
const colonyId = colonyName ?? location.pathname.replace("/colony/", "");
|
||||
tasksPanel = (
|
||||
<TaskListPanelStacked
|
||||
templateTaskListId={colonyTaskListId(colonyId)}
|
||||
queenSessionTaskListId={sessionTaskListId("queen", sessionId)}
|
||||
sessionId={sessionId}
|
||||
onClose={() => setTaskPanelDismissed(true)}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
tasksPanel = (
|
||||
<TaskListPanel
|
||||
taskListId={sessionTaskListId("queen", sessionId)}
|
||||
sessionId={sessionId}
|
||||
variant="rail"
|
||||
onClose={() => setTaskPanelDismissed(true)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-screen bg-background overflow-hidden">
|
||||
@@ -89,6 +128,7 @@ function LayoutShell({
|
||||
onClose={toggleColonyWorkers}
|
||||
/>
|
||||
)}
|
||||
{tasksPanel}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -140,13 +140,26 @@ export default function PromptLibrary() {
|
||||
promptsApi.list().then((r) => setCustomPrompts(r.prompts)).catch(() => {});
|
||||
}, []);
|
||||
|
||||
// Merge built-in + custom prompts
|
||||
const allPrompts = useMemo(() => [...customPrompts, ...prompts], [customPrompts]);
|
||||
// Filtered custom (my) prompts
|
||||
const filteredCustom = useMemo(() => {
|
||||
let result: (Prompt | CustomPrompt)[] = customPrompts;
|
||||
if (selectedCategory && selectedCategory !== "custom") {
|
||||
result = result.filter((p) => p.category === selectedCategory);
|
||||
}
|
||||
if (searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase();
|
||||
result = result.filter(
|
||||
(p) => p.title.toLowerCase().includes(query) || p.content.toLowerCase().includes(query),
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [customPrompts, searchQuery, selectedCategory]);
|
||||
|
||||
const filteredPrompts = useMemo(() => {
|
||||
let result = allPrompts;
|
||||
// Filtered built-in (community) prompts
|
||||
const filteredBuiltIn = useMemo(() => {
|
||||
let result: Prompt[] = prompts;
|
||||
if (selectedCategory === "custom") {
|
||||
result = result.filter((p) => "custom" in p && p.custom);
|
||||
result = [];
|
||||
} else if (selectedCategory) {
|
||||
result = result.filter((p) => p.category === selectedCategory);
|
||||
}
|
||||
@@ -157,13 +170,13 @@ export default function PromptLibrary() {
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [allPrompts, searchQuery, selectedCategory]);
|
||||
}, [searchQuery, selectedCategory]);
|
||||
|
||||
// Reset page when filters change
|
||||
useEffect(() => setPage(0), [searchQuery, selectedCategory]);
|
||||
|
||||
const totalPages = Math.max(1, Math.ceil(filteredPrompts.length / PAGE_SIZE));
|
||||
const pagedPrompts = filteredPrompts.slice(page * PAGE_SIZE, (page + 1) * PAGE_SIZE);
|
||||
const totalPages = Math.max(1, Math.ceil(filteredBuiltIn.length / PAGE_SIZE));
|
||||
const pagedBuiltIn = filteredBuiltIn.slice(page * PAGE_SIZE, (page + 1) * PAGE_SIZE);
|
||||
|
||||
const handleUsePrompt = (content: string, category: string) => {
|
||||
const queenId = categoryToQueen[category];
|
||||
@@ -196,7 +209,7 @@ export default function PromptLibrary() {
|
||||
Prompt Library
|
||||
</h2>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{allPrompts.length} prompts across {promptCategories.length + (customCount > 0 ? 1 : 0)} categories
|
||||
{customCount > 0 ? `${customCount} custom · ` : ""}{prompts.length} community prompts
|
||||
</span>
|
||||
</div>
|
||||
<button onClick={() => setAddModalOpen(true)}
|
||||
@@ -240,23 +253,47 @@ export default function PromptLibrary() {
|
||||
|
||||
{/* Prompts grid */}
|
||||
<div className="flex-1 overflow-y-auto p-6">
|
||||
{pagedPrompts.length > 0 ? (
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{pagedPrompts.map((prompt) => (
|
||||
<PromptCard
|
||||
key={typeof prompt.id === "string" ? prompt.id : `builtin-${prompt.id}`}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
onDelete={"custom" in prompt && prompt.custom ? () => handleDeletePrompt(prompt.id as string) : undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
{filteredCustom.length === 0 && pagedBuiltIn.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center h-full text-center">
|
||||
<Sparkles className="w-10 h-10 text-muted-foreground/30 mb-3" />
|
||||
<p className="text-sm text-muted-foreground">No prompts found</p>
|
||||
<p className="text-xs text-muted-foreground/60 mt-1">Try adjusting your search or category filter</p>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* My Prompts section */}
|
||||
{filteredCustom.length > 0 && (
|
||||
<div className="mb-8">
|
||||
<h3 className="text-xs font-semibold text-muted-foreground uppercase tracking-wider mb-3">My Prompts</h3>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{filteredCustom.map((prompt) => (
|
||||
<PromptCard
|
||||
key={prompt.id as string}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
onDelete={"custom" in prompt && prompt.custom ? () => handleDeletePrompt(prompt.id as string) : undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Community Prompts section */}
|
||||
{pagedBuiltIn.length > 0 && selectedCategory !== "custom" && (
|
||||
<div>
|
||||
<h3 className="text-xs font-semibold text-muted-foreground uppercase tracking-wider mb-3">Community Prompts</h3>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{pagedBuiltIn.map((prompt) => (
|
||||
<PromptCard
|
||||
key={`builtin-${prompt.id}`}
|
||||
prompt={prompt}
|
||||
onUse={handleUsePrompt}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -264,7 +301,7 @@ export default function PromptLibrary() {
|
||||
{totalPages > 1 && (
|
||||
<div className="px-6 py-3 border-t border-border/60 flex items-center justify-between">
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{page * PAGE_SIZE + 1}–{Math.min((page + 1) * PAGE_SIZE, filteredPrompts.length)} of {filteredPrompts.length}
|
||||
{page * PAGE_SIZE + 1}–{Math.min((page + 1) * PAGE_SIZE, filteredBuiltIn.length)} of {filteredBuiltIn.length}
|
||||
</span>
|
||||
<div className="flex items-center gap-1">
|
||||
<button onClick={() => setPage((p) => Math.max(0, p - 1))} disabled={page === 0}
|
||||
|
||||
@@ -18,11 +18,38 @@ import {
|
||||
replayEventsToMessages,
|
||||
} from "@/lib/chat-helpers";
|
||||
import { useColony } from "@/context/ColonyContext";
|
||||
import { useColonyWorkers } from "@/context/ColonyWorkersContext";
|
||||
import { useHeaderActions } from "@/context/HeaderActionsContext";
|
||||
import { getQueenForAgent, slugToColonyId } from "@/lib/colony-registry";
|
||||
|
||||
const makeId = () => Math.random().toString(36).slice(2, 9);
|
||||
|
||||
// Remembers the last session the user had open in each queen DM so that
|
||||
// navigating away (e.g. to another queen) and back lands on the session
|
||||
// they were just in, instead of whichever session the server picks.
|
||||
const lastSessionKey = (queenId: string) => `hive:queen:${queenId}:lastSession`;
|
||||
const readLastSession = (queenId: string): string | null => {
|
||||
try {
|
||||
return localStorage.getItem(lastSessionKey(queenId));
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
const writeLastSession = (queenId: string, sessionId: string) => {
|
||||
try {
|
||||
localStorage.setItem(lastSessionKey(queenId), sessionId);
|
||||
} catch {
|
||||
/* storage disabled/full — best-effort */
|
||||
}
|
||||
};
|
||||
const clearLastSession = (queenId: string) => {
|
||||
try {
|
||||
localStorage.removeItem(lastSessionKey(queenId));
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
};
|
||||
|
||||
export default function QueenDM() {
|
||||
const { queenId } = useParams<{ queenId: string }>();
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
@@ -45,7 +72,17 @@ export default function QueenDM() {
|
||||
{ id: string; prompt: string; options?: string[] }[] | null
|
||||
>(null);
|
||||
const [awaitingInput, setAwaitingInput] = useState(false);
|
||||
const [tokenUsage, setTokenUsage] = useState({ input: 0, output: 0 });
|
||||
// `cached` and `cacheCreated` are subsets of `input` (providers count both
|
||||
// inside prompt_tokens already) — display them, never add them to a total.
|
||||
// `costUsd` is the session-total USD cost when the provider supplies one
|
||||
// (Anthropic, OpenAI, OpenRouter); 0 means unreported, not free.
|
||||
const [tokenUsage, setTokenUsage] = useState({
|
||||
input: 0,
|
||||
output: 0,
|
||||
cached: 0,
|
||||
cacheCreated: 0,
|
||||
costUsd: 0,
|
||||
});
|
||||
const [historySessions, setHistorySessions] = useState<HistorySession[]>([]);
|
||||
const [historyLoading, setHistoryLoading] = useState(false);
|
||||
const [switchingSessionId, setSwitchingSessionId] = useState<string | null>(
|
||||
@@ -83,6 +120,16 @@ export default function QueenDM() {
|
||||
"independent" | "incubating" | "working" | "reviewing"
|
||||
>("independent");
|
||||
|
||||
// Publish the active session id into the shared workers/tasks context
|
||||
// so AppLayout's right-rail TaskListPanel can attach to it. The colony
|
||||
// workers panel itself stays hidden in queen-DM because we don't set
|
||||
// colonyName (AppLayout requires both — see LayoutShell).
|
||||
const { setSessionId: setCtxSessionId } = useColonyWorkers();
|
||||
useEffect(() => {
|
||||
setCtxSessionId(sessionId ?? null);
|
||||
return () => setCtxSessionId(null);
|
||||
}, [sessionId, setCtxSessionId]);
|
||||
|
||||
const resetViewState = useCallback(() => {
|
||||
setSessionId(null);
|
||||
setMessages([]);
|
||||
@@ -92,7 +139,7 @@ export default function QueenDM() {
|
||||
setPendingQuestions(null);
|
||||
setAwaitingInput(false);
|
||||
setQueenPhase("independent");
|
||||
setTokenUsage({ input: 0, output: 0 });
|
||||
setTokenUsage({ input: 0, output: 0, cached: 0, cacheCreated: 0, costUsd: 0 });
|
||||
setInitialDraft(null);
|
||||
setColonySpawned(false);
|
||||
setSpawnedColonyName(null);
|
||||
@@ -160,6 +207,31 @@ export default function QueenDM() {
|
||||
);
|
||||
replayStateRef.current = replayState;
|
||||
|
||||
// Sum historical llm_turn_complete events so Tokens/Cost carry over
|
||||
// across resume. SSE does not replay llm_turn_complete (see
|
||||
// routes_events.py _REPLAY_TYPES), so no double-count risk — live
|
||||
// SSE deltas that may have already landed are kept via functional
|
||||
// merge below.
|
||||
const seed = { input: 0, output: 0, cached: 0, cacheCreated: 0, costUsd: 0 };
|
||||
for (const evt of events) {
|
||||
if (evt.type !== "llm_turn_complete" || !evt.data) continue;
|
||||
const d = evt.data as Record<string, unknown>;
|
||||
seed.input += (d.input_tokens as number) || 0;
|
||||
seed.output += (d.output_tokens as number) || 0;
|
||||
seed.cached += (d.cached_tokens as number) || 0;
|
||||
seed.cacheCreated += (d.cache_creation_tokens as number) || 0;
|
||||
seed.costUsd += (d.cost_usd as number) || 0;
|
||||
}
|
||||
if (!cancelled()) {
|
||||
setTokenUsage((prev) => ({
|
||||
input: prev.input + seed.input,
|
||||
output: prev.output + seed.output,
|
||||
cached: prev.cached + seed.cached,
|
||||
cacheCreated: prev.cacheCreated + seed.cacheCreated,
|
||||
costUsd: prev.costUsd + seed.costUsd,
|
||||
}));
|
||||
}
|
||||
|
||||
// Show a banner if the server truncated older events.
|
||||
const droppedCount = Math.max(0, total - returned);
|
||||
if (truncated && droppedCount > 0) {
|
||||
@@ -199,6 +271,19 @@ export default function QueenDM() {
|
||||
useEffect(() => {
|
||||
if (!queenId) return;
|
||||
|
||||
// If we arrived without an explicit session in the URL and aren't
|
||||
// bootstrapping a new one, redirect to the last session the user had
|
||||
// open for this queen. Session IDs are always of the form
|
||||
// "session_<timestamp>_<hex>", so we gate on that prefix to avoid
|
||||
// redirecting to anything unexpected that landed in storage.
|
||||
if (!selectedSessionParam && newSessionFlag !== "1") {
|
||||
const stored = readLastSession(queenId);
|
||||
if (stored && stored.startsWith("session_")) {
|
||||
setSearchParams({ session: stored }, { replace: true });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
resetViewState();
|
||||
setLoading(true);
|
||||
|
||||
@@ -314,7 +399,17 @@ export default function QueenDM() {
|
||||
await restoreMessages(sid, () => cancelled);
|
||||
refresh();
|
||||
} catch {
|
||||
// Session creation failed
|
||||
// Session creation/selection failed. If the URL param came from
|
||||
// our own localStorage restore, the stored session is stale (e.g.
|
||||
// deleted on disk) — clear it so the next navigation falls
|
||||
// through to getOrCreate instead of looping on the bad id.
|
||||
if (
|
||||
queenId &&
|
||||
selectedSessionParam &&
|
||||
selectedSessionParam === readLastSession(queenId)
|
||||
) {
|
||||
clearLastSession(queenId);
|
||||
}
|
||||
} finally {
|
||||
if (!cancelled) {
|
||||
setLoading(false);
|
||||
@@ -337,6 +432,13 @@ export default function QueenDM() {
|
||||
setSearchParams,
|
||||
]);
|
||||
|
||||
// Remember the session the user is currently viewing so switching queens
|
||||
// and coming back lands on it instead of whatever the server picks.
|
||||
useEffect(() => {
|
||||
if (!queenId || !sessionId) return;
|
||||
writeLastSession(queenId, sessionId);
|
||||
}, [queenId, sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!queenId) return;
|
||||
let cancelled = false;
|
||||
@@ -520,7 +622,18 @@ export default function QueenDM() {
|
||||
if (event.data) {
|
||||
const inp = (event.data.input_tokens as number) || 0;
|
||||
const out = (event.data.output_tokens as number) || 0;
|
||||
setTokenUsage((prev) => ({ input: prev.input + inp, output: prev.output + out }));
|
||||
// cached / cache_creation are subsets of input — accumulate
|
||||
// separately for display, do NOT roll into input/total.
|
||||
const cached = (event.data.cached_tokens as number) || 0;
|
||||
const cacheCreated = (event.data.cache_creation_tokens as number) || 0;
|
||||
const costUsd = (event.data.cost_usd as number) || 0;
|
||||
setTokenUsage((prev) => ({
|
||||
input: prev.input + inp,
|
||||
output: prev.output + out,
|
||||
cached: prev.cached + cached,
|
||||
cacheCreated: prev.cacheCreated + cacheCreated,
|
||||
costUsd: prev.costUsd + costUsd,
|
||||
}));
|
||||
}
|
||||
// Flush one queued message per LLM turn boundary. This is the
|
||||
// real "turn ended" signal in a queen DM — execution_completed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,283 @@
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { Wrench, Crown, Network, Server, Loader2, AlertCircle } from "lucide-react";
|
||||
import { queensApi } from "@/api/queens";
|
||||
import { coloniesApi, type ColonySummary } from "@/api/colonies";
|
||||
import { slugToDisplayName, sortQueenProfiles } from "@/lib/colony-registry";
|
||||
import QueenToolsSection from "@/components/QueenToolsSection";
|
||||
import ColonyToolsSection from "@/components/ColonyToolsSection";
|
||||
import McpServersPanel from "@/components/McpServersPanel";
|
||||
|
||||
type Tab = "queens" | "colonies" | "mcp";
|
||||
|
||||
export default function ToolLibrary() {
|
||||
const [tab, setTab] = useState<Tab>("queens");
|
||||
|
||||
return (
|
||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||
{/* Header */}
|
||||
<div className="px-6 py-4 border-b border-border/60">
|
||||
<div className="flex items-baseline gap-3 mb-3">
|
||||
<h2 className="text-lg font-semibold text-foreground flex items-center gap-2">
|
||||
<Wrench className="w-5 h-5 text-primary" />
|
||||
Tool Configuration
|
||||
</h2>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
Curate which tools each queen and colony can call, and register your own MCP servers.
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<TabButton active={tab === "queens"} onClick={() => setTab("queens")} icon={<Crown className="w-3.5 h-3.5" />}>
|
||||
Queens
|
||||
</TabButton>
|
||||
<TabButton active={tab === "colonies"} onClick={() => setTab("colonies")} icon={<Network className="w-3.5 h-3.5" />}>
|
||||
Colonies
|
||||
</TabButton>
|
||||
<TabButton active={tab === "mcp"} onClick={() => setTab("mcp")} icon={<Server className="w-3.5 h-3.5" />}>
|
||||
MCP Servers
|
||||
</TabButton>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
{tab === "queens" && <QueensTab />}
|
||||
{tab === "colonies" && <ColoniesTab />}
|
||||
{tab === "mcp" && (
|
||||
<div className="px-6 py-6 max-w-4xl">
|
||||
<McpServersPanel />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function TabButton({
|
||||
active,
|
||||
onClick,
|
||||
icon,
|
||||
children,
|
||||
}: {
|
||||
active: boolean;
|
||||
onClick: () => void;
|
||||
icon: React.ReactNode;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`flex items-center gap-1.5 px-3 py-1.5 rounded-md text-sm font-medium ${
|
||||
active
|
||||
? "bg-primary/15 text-primary"
|
||||
: "text-muted-foreground hover:text-foreground hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
{icon}
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Queens tab ---------------------------------------------------------
|
||||
|
||||
function QueensTab() {
|
||||
const [queens, setQueens] = useState<Array<{ id: string; name: string; title: string }> | null>(null);
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
queensApi
|
||||
.list()
|
||||
.then((r) => {
|
||||
const sorted = sortQueenProfiles(r.queens);
|
||||
setQueens(sorted);
|
||||
if (sorted.length > 0) setSelected((prev) => prev ?? sorted[0].id);
|
||||
})
|
||||
.catch((e: Error) => setError(e.message || "Failed to load queens"));
|
||||
}, []);
|
||||
|
||||
if (error) return <ErrorBlock message={error} />;
|
||||
if (queens === null) return <LoadingBlock label="Loading queens…" />;
|
||||
if (queens.length === 0)
|
||||
return <EmptyBlock label="No queens yet. Create one to curate its tools." />;
|
||||
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<SidePicker>
|
||||
{queens.map((q) => (
|
||||
<PickerItem
|
||||
key={q.id}
|
||||
active={selected === q.id}
|
||||
onClick={() => setSelected(q.id)}
|
||||
primary={q.name}
|
||||
secondary={q.title}
|
||||
/>
|
||||
))}
|
||||
</SidePicker>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-5 min-w-0">
|
||||
{selected ? (
|
||||
<>
|
||||
{(() => {
|
||||
const queen = queens.find((q) => q.id === selected);
|
||||
return queen ? (
|
||||
<div className="mb-4 pb-3 border-b border-border/40">
|
||||
<h3 className="text-base font-semibold text-foreground">
|
||||
{queen.name}
|
||||
</h3>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
{queen.title}
|
||||
</p>
|
||||
</div>
|
||||
) : null;
|
||||
})()}
|
||||
<QueenToolsSection queenId={selected} />
|
||||
</>
|
||||
) : (
|
||||
<EmptyBlock label="Pick a queen to edit her tool allowlist." />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Colonies tab -------------------------------------------------------
|
||||
|
||||
function ColoniesTab() {
|
||||
const [colonies, setColonies] = useState<ColonySummary[] | null>(null);
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
coloniesApi
|
||||
.list()
|
||||
.then((r) => {
|
||||
setColonies(r.colonies);
|
||||
if (r.colonies.length > 0)
|
||||
setSelected((prev) => prev ?? r.colonies[0].name);
|
||||
})
|
||||
.catch((e: Error) => setError(e.message || "Failed to load colonies"));
|
||||
}, []);
|
||||
|
||||
const sorted = useMemo(() => {
|
||||
if (!colonies) return null;
|
||||
return [...colonies].sort((a, b) => a.name.localeCompare(b.name));
|
||||
}, [colonies]);
|
||||
|
||||
if (error) return <ErrorBlock message={error} />;
|
||||
if (sorted === null) return <LoadingBlock label="Loading colonies…" />;
|
||||
if (sorted.length === 0)
|
||||
return (
|
||||
<EmptyBlock label="No colonies yet. Ask a queen to incubate one and its tools will show up here." />
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex h-full">
|
||||
<SidePicker>
|
||||
{sorted.map((c) => (
|
||||
<PickerItem
|
||||
key={c.name}
|
||||
active={selected === c.name}
|
||||
onClick={() => setSelected(c.name)}
|
||||
primary={slugToDisplayName(c.name)}
|
||||
secondary={
|
||||
c.has_allowlist
|
||||
? `${c.enabled_count ?? 0} tools allowed · ${c.queen_name ?? ""}`
|
||||
: `all tools · ${c.queen_name ?? ""}`
|
||||
}
|
||||
tertiary={c.name}
|
||||
/>
|
||||
))}
|
||||
</SidePicker>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-5 min-w-0">
|
||||
{selected ? (
|
||||
<>
|
||||
<div className="mb-4 pb-3 border-b border-border/40">
|
||||
<h3 className="text-base font-semibold text-foreground">
|
||||
{slugToDisplayName(selected)}
|
||||
</h3>
|
||||
<p className="text-[11px] text-muted-foreground font-mono mt-0.5">
|
||||
{selected}
|
||||
</p>
|
||||
</div>
|
||||
<ColonyToolsSection colonyName={selected} />
|
||||
</>
|
||||
) : (
|
||||
<EmptyBlock label="Pick a colony to edit its tool allowlist." />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ----- Shared primitives --------------------------------------------------
|
||||
|
||||
function SidePicker({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<div className="w-[260px] flex-shrink-0 border-r border-border/60 overflow-y-auto py-3 px-2 flex flex-col gap-1">
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function PickerItem({
|
||||
active,
|
||||
onClick,
|
||||
primary,
|
||||
secondary,
|
||||
tertiary,
|
||||
}: {
|
||||
active: boolean;
|
||||
onClick: () => void;
|
||||
primary: string;
|
||||
secondary?: string;
|
||||
tertiary?: string;
|
||||
}) {
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`text-left px-3 py-2 rounded-md text-sm ${
|
||||
active
|
||||
? "bg-primary/15 text-primary"
|
||||
: "text-foreground hover:bg-muted/30"
|
||||
}`}
|
||||
>
|
||||
<div className="font-medium truncate">{primary}</div>
|
||||
{secondary && (
|
||||
<div className="text-[11px] text-muted-foreground truncate">
|
||||
{secondary}
|
||||
</div>
|
||||
)}
|
||||
{tertiary && (
|
||||
<div className="text-[10px] text-muted-foreground/60 font-mono truncate">
|
||||
{tertiary}
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
function LoadingBlock({ label }: { label: string }) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground px-6 py-6">
|
||||
<Loader2 className="w-3 h-3 animate-spin" />
|
||||
{label}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EmptyBlock({ label }: { label: string }) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-muted-foreground px-6 py-6">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5" />
|
||||
<span>{label}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ErrorBlock({ message }: { message: string }) {
|
||||
return (
|
||||
<div className="flex items-start gap-2 text-xs text-destructive px-6 py-6">
|
||||
<AlertCircle className="w-3.5 h-3.5 mt-0.5 flex-shrink-0" />
|
||||
<span>{message}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -42,6 +42,7 @@ _HIVE_PATH_CONSUMERS = (
|
||||
"framework.server.session_manager",
|
||||
"framework.server.queen_orchestrator",
|
||||
"framework.server.routes_queens",
|
||||
"framework.server.routes_skills",
|
||||
"framework.server.app",
|
||||
"framework.agents.discovery",
|
||||
"framework.agents.queen.queen_profiles",
|
||||
|
||||
@@ -9,26 +9,25 @@ from framework.llm.provider import Tool
|
||||
|
||||
|
||||
class TestSupportsImageToolResults:
|
||||
"""Verify the deny-list correctly identifies models that can't handle images."""
|
||||
"""Verify catalog-driven vision capability checks."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"openai/gpt-4o",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
# Catalog entries with supports_vision=true
|
||||
"claude-haiku-4-5-20251001",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"google/gemini-1.5-flash",
|
||||
"mistral/mistral-large",
|
||||
"groq/llama3-70b",
|
||||
"together/meta-llama/Llama-3-70b",
|
||||
"fireworks_ai/llama-v3-70b",
|
||||
"azure/gpt-4o",
|
||||
"kimi/claude-sonnet-4-20250514",
|
||||
"hive/claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-6",
|
||||
"gpt-5.4",
|
||||
"gpt-5.4-mini",
|
||||
"gemini-3-flash-preview",
|
||||
"kimi-k2.5",
|
||||
# Provider-prefixed catalog entries
|
||||
"openrouter/openai/gpt-5.4",
|
||||
"openrouter/anthropic/claude-sonnet-4.6",
|
||||
# Unknown models default to True (hosted frontier assumption)
|
||||
"some-future-model",
|
||||
"azure/gpt-5",
|
||||
],
|
||||
)
|
||||
def test_supported_models(self, model: str):
|
||||
@@ -37,27 +36,24 @@ class TestSupportsImageToolResults:
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"deepseek/deepseek-chat",
|
||||
"deepseek/deepseek-coder",
|
||||
"deepseek-chat",
|
||||
# Catalog entries with supports_vision=false
|
||||
"deepseek-reasoner",
|
||||
"ollama/llama3",
|
||||
"ollama/mistral",
|
||||
"ollama_chat/llama3",
|
||||
"lm_studio/my-model",
|
||||
"vllm/meta-llama/Llama-3-70b",
|
||||
"llamacpp/model",
|
||||
"cerebras/llama3-70b",
|
||||
"deepseek-v4-pro",
|
||||
"deepseek-v4-flash",
|
||||
"glm-5.1",
|
||||
"queen",
|
||||
"MiniMax-M2.7",
|
||||
"codestral-2508",
|
||||
"llama-3.3-70b-versatile",
|
||||
# Provider-prefixed forms resolve to the same catalog entry
|
||||
"deepseek/deepseek-reasoner",
|
||||
"hive/glm-5.1",
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
],
|
||||
)
|
||||
def test_unsupported_models(self, model: str):
|
||||
assert supports_image_tool_results(model) is False
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert supports_image_tool_results("DeepSeek/deepseek-chat") is False
|
||||
assert supports_image_tool_results("OLLAMA/llama3") is False
|
||||
assert supports_image_tool_results("GPT-4o") is True
|
||||
|
||||
|
||||
class TestFilterToolsForModel:
|
||||
"""Verify ``filter_tools_for_model`` — the real helper used by AgentLoop."""
|
||||
@@ -68,7 +64,7 @@ class TestFilterToolsForModel:
|
||||
Tool(name="browser_screenshot", description="take a screenshot", produces_image=True),
|
||||
Tool(name="browser_snapshot", description="get page content"),
|
||||
]
|
||||
filtered, hidden = filter_tools_for_model(tools, "glm-5")
|
||||
filtered, hidden = filter_tools_for_model(tools, "glm-5.1")
|
||||
names = [t.name for t in filtered]
|
||||
assert "browser_screenshot" not in names
|
||||
assert "read_file" in names
|
||||
@@ -80,7 +76,7 @@ class TestFilterToolsForModel:
|
||||
Tool(name="read_file", description="read a file"),
|
||||
Tool(name="browser_screenshot", description="take a screenshot", produces_image=True),
|
||||
]
|
||||
filtered, hidden = filter_tools_for_model(tools, "claude-sonnet-4-20250514")
|
||||
filtered, hidden = filter_tools_for_model(tools, "claude-sonnet-4-5-20250929")
|
||||
assert {t.name for t in filtered} == {"read_file", "browser_screenshot"}
|
||||
assert hidden == []
|
||||
|
||||
@@ -90,8 +86,8 @@ class TestFilterToolsForModel:
|
||||
Tool(name="read_file", description="read a file"),
|
||||
Tool(name="web_search", description="search the web"),
|
||||
]
|
||||
text_only, text_hidden = filter_tools_for_model(tools, "glm-5")
|
||||
vision, vision_hidden = filter_tools_for_model(tools, "gpt-4o")
|
||||
text_only, text_hidden = filter_tools_for_model(tools, "glm-5.1")
|
||||
vision, vision_hidden = filter_tools_for_model(tools, "claude-sonnet-4-5-20250929")
|
||||
assert len(text_only) == 2 and text_hidden == []
|
||||
assert len(vision) == 2 and vision_hidden == []
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ needs and run everything against a temp directory.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
@@ -278,6 +279,17 @@ async def test_colony_spawn_creates_correct_artifacts(tmp_path, monkeypatch):
|
||||
assert resp.status == 200, await resp.text()
|
||||
body = await resp.json()
|
||||
|
||||
# fork_session_into_colony schedules the compaction + worker-storage
|
||||
# copy onto _BACKGROUND_FORK_TASKS and returns. In prod the colony-
|
||||
# open path blocks on compaction_status.await_completion; the test
|
||||
# skips that step, so drain the bg tasks here before asserting on
|
||||
# the artifacts they produce (otherwise the worker-storage check is
|
||||
# a race that flakes under CI load).
|
||||
from framework.server.routes_execution import _BACKGROUND_FORK_TASKS
|
||||
|
||||
if _BACKGROUND_FORK_TASKS:
|
||||
await asyncio.gather(*list(_BACKGROUND_FORK_TASKS), return_exceptions=True)
|
||||
|
||||
colony_session_id = body["queen_session_id"]
|
||||
assert body["colony_name"] == "honeycomb"
|
||||
assert body["is_new"] is True
|
||||
|
||||
@@ -63,6 +63,7 @@ class MockStreamingLLM(LLMProvider):
|
||||
system: str = "",
|
||||
tools: list[Tool] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
**kwargs,
|
||||
) -> AsyncIterator:
|
||||
self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
|
||||
|
||||
@@ -311,7 +312,7 @@ class TestReportToParent:
|
||||
model: str = "mock"
|
||||
stream_calls: list[dict] = []
|
||||
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096):
|
||||
async def stream(self, messages, system="", tools=None, max_tokens=4096, **kwargs):
|
||||
self.stream_calls.append({"messages": messages})
|
||||
raise RuntimeError("boom — simulated LLM crash")
|
||||
yield # pragma: no cover — make this an async generator
|
||||
@@ -485,8 +486,12 @@ class TestReportToParentGatingByStream:
|
||||
try:
|
||||
# Spawn a parallel worker — its tool list should include report_to_parent
|
||||
await colony.spawn(task="test", count=1)
|
||||
# After the worker's first LLM call, check the recorded tools
|
||||
await asyncio.sleep(0.2) # let the background task run
|
||||
# Poll until the worker fires its first LLM call. Bare sleeps were
|
||||
# flaky on slow Windows CI; loop with a generous deadline instead.
|
||||
for _ in range(100):
|
||||
if llm.stream_calls:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
assert llm.stream_calls, "Worker never called the LLM"
|
||||
worker_tools = llm.stream_calls[0]["tools"]
|
||||
tool_names = [t.name for t in (worker_tools or [])]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user