fix: remove worker tool from dm

Timothy
2026-04-16 12:23:19 -07:00
parent 7b0948cd62
commit 37672c5581
10 changed files with 320 additions and 74 deletions
+10 -2
@@ -184,8 +184,16 @@ _QUEEN_INDEPENDENT_TOOLS = [
"search_files",
"run_command",
"undo_changes",
# Parallel fan-out (Phase 4 unified ColonyRuntime)
"run_parallel_workers",
# NOTE (2026-04-16): ``run_parallel_workers`` was removed from the
# independent phase. The queen's pure DM mode is for conversation
# with the user; spawning workers from here puts their activity
# into a chat surface that's supposed to stay queen↔user only.
# Users who want to fan out parallel work should (a) use
# ``create_colony`` to fork into a persistent colony (where
# worker activity has its own page), or (b) load an agent via
# build/stage and use ``run_parallel_workers`` in the running
# phase where a worker context already exists.
#
# Fork this session into a persistent colony for headless /
# recurring / background work that needs to keep running in
# parallel to (or after) this chat.
+9
@@ -757,6 +757,8 @@ class ColonyRuntime:
async def spawn_batch(
self,
tasks: list[dict[str, Any]],
*,
tools_override: list[Any] | None = None,
) -> list[str]:
"""Spawn a batch of parallel workers, one per task spec.
@@ -769,6 +771,12 @@ class ColonyRuntime:
The overseer's ``run_parallel_workers`` tool is the usual
caller; it pairs ``spawn_batch`` + ``wait_for_worker_reports``
into a single fan-out/fan-in primitive.
When ``tools_override`` is supplied, every spawned worker
receives that tool list instead of the colony's default. Used
by ``run_parallel_workers`` to drop tools whose credentials
failed the pre-flight check (so the spawned workers don't
waste a startup trying to use them).
"""
worker_ids: list[str] = []
for spec in tasks:
@@ -780,6 +788,7 @@ class ColonyRuntime:
task=task_text,
count=1,
input_data=task_data or {"task": task_text},
tools=tools_override,
)
worker_ids.extend(ids)
return worker_ids
+22 -9
@@ -51,13 +51,18 @@ DEFAULT_EVENT_TYPES = [
# Keepalive interval in seconds
KEEPALIVE_INTERVAL = 15.0
# Phase 5 SSE filter: parallel-worker streams (stream_id="worker:{uuid}")
# publish high-frequency LLM deltas / tool calls that would flood the
# user's queen DM chat. We let only this small allowlist of worker
# events through to the queen-chat SSE so the frontend can render
# fan-out lifecycle and structured fan-in reports without seeing the
# raw worker chatter. Per-worker SSE panels (Phase 5b) bypass this
# filter via a dedicated /workers/{worker_id}/events route.
# Session-SSE worker filter: workers run outside the queen's DM
# chat. Worker activity is observable via the dedicated
# ``/api/workers/{worker_id}/events`` per-worker SSE route, not via
# the session chat. This keeps the queen↔user conversation clean of
# tool-call chatter regardless of whether the worker was spawned by
# ``run_agent_with_input`` (stream_id="worker") or
# ``run_parallel_workers`` (stream_id="worker:{uuid}").
#
# Lifecycle events the frontend needs for fan-in summaries
# (SUBAGENT_REPORT, EXECUTION_COMPLETED, EXECUTION_FAILED) are still
# allowed through so the queen can show "N workers done" surfaces
# without exposing the per-turn chatter.
_WORKER_EVENT_ALLOWLIST = {
EventType.SUBAGENT_REPORT.value,
EventType.EXECUTION_COMPLETED.value,
@@ -66,9 +71,17 @@ _WORKER_EVENT_ALLOWLIST = {
def _is_worker_noise(evt_dict: dict) -> bool:
"""True if the event is a parallel-worker event we should drop."""
"""True if the event belongs to a worker stream and should not
surface in the queen DM chat.
Matches any stream starting with ``worker``: both the bare
``"worker"`` tag used by single-worker spawns and the
``"worker:{uuid}"`` tag used by parallel fan-outs. The allowlist
carves out the three terminal/lifecycle events the UI still
needs to render fan-in summaries.
"""
stream_id = evt_dict.get("stream_id") or ""
if not stream_id.startswith("worker:"):
if not stream_id.startswith("worker"):
return False
return evt_dict.get("type") not in _WORKER_EVENT_ALLOWLIST
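A self-contained sketch of the widened filter, with plain strings standing in for the ``EventType`` members named above:

    _ALLOW = {"subagent_report", "execution_completed", "execution_failed"}

    def is_worker_noise(evt: dict) -> bool:
        stream_id = evt.get("stream_id") or ""
        if not stream_id.startswith("worker"):
            return False  # queen/user events always pass through
        return evt.get("type") not in _ALLOW

    # Both worker stream shapes are now caught; lifecycle events survive.
    assert is_worker_noise({"stream_id": "worker:ab12", "type": "llm_delta"})
    assert is_worker_noise({"stream_id": "worker", "type": "tool_call"})
    assert not is_worker_noise({"stream_id": "worker:ab12", "type": "subagent_report"})
    assert not is_worker_noise({"stream_id": "queen", "type": "llm_delta"})

The second assert is the behavioral change: under the old ``startswith("worker:")`` check, bare ``"worker"`` streams leaked through to the queen chat.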
+165 -55
@@ -903,10 +903,76 @@ def register_queen_lifecycle_tools(
# ``start_worker`` was removed in the Phase 4 unification — its
# bare-bones spawn duplicated ``run_agent_with_input`` (which has
# credential preflight, concurrency guard, and phase tracking on
# top). The shared preflight timeout below is still used by
# ``run_agent_with_input``.
# top). The shared preflight timeout below is used by both
# ``run_agent_with_input`` and ``run_parallel_workers``.
_START_PREFLIGHT_TIMEOUT = 15 # seconds
async def _preflight_credentials(
legacy: Any,
*,
tool_label: str,
) -> set[str]:
"""Compute tools whose credentials are missing and resync MCP servers.
Shared between ``run_agent_with_input`` (single spawn) and
``run_parallel_workers`` (batch spawn). Returns the set of
tool names whose credentials failed validation; the caller
filters these out of the spawn's tool lists.
Exceptions (including validator bugs) are logged and treated
as "no tools dropped" so a broken validator can't block a
spawn. Wall-clock bound at ``_START_PREFLIGHT_TIMEOUT`` so
slow credential HTTP health checks can't stall the LLM turn.
"""
unavailable: set[str] = set()
async def _run() -> None:
nonlocal unavailable
try:
from framework.credentials.validation import compute_unavailable_tools
loop = asyncio.get_running_loop()
drop, messages = await loop.run_in_executor(
None,
lambda: compute_unavailable_tools(legacy.graph.nodes),
)
unavailable = drop
if drop:
logger.warning(
"%s: dropping %d tool(s) with unavailable credentials: %s",
tool_label,
len(drop),
"; ".join(messages),
)
except Exception as exc:
logger.warning(
"%s: compute_unavailable_tools raised, proceeding without "
"credential-based tool filtering: %s",
tool_label,
exc,
)
runner = getattr(session, "runner", None)
if runner is not None:
try:
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
)
except Exception as exc:
logger.warning("%s: MCP resync failed: %s", tool_label, exc)
try:
await asyncio.wait_for(_run(), timeout=_START_PREFLIGHT_TIMEOUT)
except TimeoutError:
logger.warning(
"%s: credential preflight timed out after %ds — proceeding",
tool_label,
_START_PREFLIGHT_TIMEOUT,
)
return unavailable
# --- stop_worker -----------------------------------------------------------
async def stop_worker(*, reason: str = "Stopped by queen") -> str:
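The shape of the new helper, reduced to a runnable sketch: a blocking check runs in an executor, the whole preflight is capped by ``asyncio.wait_for``, and a timeout degrades to dropping nothing rather than failing the spawn. This assumes Python 3.11+, where ``asyncio.wait_for`` raises the builtin ``TimeoutError``; ``slow_credential_check`` is a stand-in for ``compute_unavailable_tools``:

    import asyncio
    import logging

    logger = logging.getLogger(__name__)
    PREFLIGHT_TIMEOUT = 15  # seconds

    def slow_credential_check() -> set[str]:
        return {"github_search"}  # pretend this token failed validation

    async def preflight(label: str) -> set[str]:
        unavailable: set[str] = set()

        async def _run() -> None:
            nonlocal unavailable
            loop = asyncio.get_running_loop()
            unavailable = await loop.run_in_executor(None, slow_credential_check)

        try:
            await asyncio.wait_for(_run(), timeout=PREFLIGHT_TIMEOUT)
        except TimeoutError:
            logger.warning("%s: preflight timed out, proceeding", label)
        return unavailable

    print(asyncio.run(preflight("demo")))  # {'github_search'}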
@@ -1078,6 +1144,51 @@ def register_queen_lifecycle_tools(
}
)
# Credential preflight — mirrors the one run_agent_with_input
# performs. Without this, missing credentials (e.g. stale
# GITHUB_TOKEN) fail once PER spawned worker, yielding N
# duplicate error reports for a single fixable issue. Catch
# once upfront, build a filtered tool list, and pass it to
# every spawn via tools_override.
legacy_for_preflight = _get_runtime()
unavailable_tools_parallel: set[str] = set()
tools_override_parallel: list[Any] | None = None
if legacy_for_preflight is not None:
try:
unavailable_tools_parallel = await _preflight_credentials(
legacy_for_preflight, tool_label="run_parallel_workers"
)
except CredentialError as e:
# Structured credential failure: publish the
# CREDENTIALS_REQUIRED event so the frontend's modal
# can fire, and return the same shape the single-path
# tool returns on the same failure.
error_payload = credential_errors_to_json(e)
error_payload["agent_path"] = str(getattr(session, "worker_path", "") or "")
bus = getattr(session, "event_bus", None)
if bus is not None:
await bus.publish(
AgentEvent(
type=EventType.CREDENTIALS_REQUIRED,
stream_id="queen",
data=error_payload,
)
)
return json.dumps(error_payload)
if unavailable_tools_parallel:
colony_tools = list(getattr(colony, "_tools", []) or [])
before = len(colony_tools)
tools_override_parallel = [
t
for t in colony_tools
if getattr(t, "name", None) not in unavailable_tools_parallel
]
logger.info(
"run_parallel_workers: dropped %d tool object(s) from spawn_tools (unavailable credentials)",
before - len(tools_override_parallel),
)
# Colony progress tracker wiring: if the session's loaded
# worker points at a colony directory that has a progress.db,
# inject db_path + colony_id into every per-task ``data``
@@ -1167,10 +1278,31 @@ def register_queen_lifecycle_tools(
)
try:
worker_ids = await colony.spawn_batch(normalised)
worker_ids = await colony.spawn_batch(
normalised,
tools_override=tools_override_parallel,
)
except Exception as e:
return json.dumps({"error": f"spawn_batch failed: {e}"})
# Phase transition — mirrors run_agent_with_input. With the
# batch now spawned, the queen is semantically "running" until
# wait_for_worker_reports returns, so phase-gated running
# tools (inject_message, reply_to_worker, ...) should be
# available. Without this change run_parallel_workers left
# the queen in whatever phase she was in (typically staging).
if phase_state is not None:
try:
await phase_state.switch_to_running()
_update_meta_json(
session_manager, manager_session_id, {"phase": "running"}
)
except Exception as exc:
logger.warning(
"run_parallel_workers: phase transition to 'running' failed (non-fatal): %s",
exc,
)
try:
reports = await colony.wait_for_worker_reports(
worker_ids,
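Put together, the fan-out path now runs spawn, then phase flip, then fan-in. A hedged sketch of that order, with the colony and phase-state objects as stand-ins for the real classes:

    async def run_parallel(colony, phase_state, tasks, tools_override=None):
        worker_ids = await colony.spawn_batch(tasks, tools_override=tools_override)
        # Batch is live: flip the queen to "running" so the phase-gated
        # tools (inject_message, reply_to_worker, ...) unlock while the
        # workers execute.
        try:
            await phase_state.switch_to_running()
        except Exception:
            pass  # non-fatal, mirroring the commit's error handling
        return await colony.wait_for_worker_reports(worker_ids)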
@@ -4030,6 +4162,33 @@ def register_queen_lifecycle_tools(
task,
)
# Concurrency budget check — mirrors run_parallel_workers so a
# queen in a loop can't silently exceed max_concurrent_workers
# by hammering run_agent_with_input. Per-call count is 1, so
# the check is ``active + 1 > max_concurrent``.
colony_cfg = getattr(colony, "_config", None) or getattr(colony, "config", None)
max_concurrent = getattr(colony_cfg, "max_concurrent_workers", None)
if max_concurrent and max_concurrent > 0:
active = 0
try:
workers = getattr(colony, "_workers", {}) or {}
for w in workers.values():
handle = getattr(w, "_task_handle", None)
if handle is not None and not handle.done():
active += 1
except Exception:
active = 0
if active + 1 > max_concurrent:
return json.dumps(
{
"error": (
f"run_agent_with_input would exceed max_concurrent_workers "
f"({active} active + 1 new > {max_concurrent}). "
"Wait for an existing worker to finish or stop one."
)
}
)
try:
# Pre-flight: compute the set of tools whose credentials are
# NOT currently available, and resync MCP servers. We do NOT
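The arithmetic of the budget guard in isolation: ``Handle`` is a minimal stand-in for an asyncio task handle, and with a budget of 3 and 3 live handles one more spawn is rejected:

    class Handle:
        def __init__(self, running: bool):
            self._running = running

        def done(self) -> bool:
            return not self._running

    def would_exceed(workers: dict, max_concurrent: int) -> bool:
        active = sum(
            1
            for w in workers.values()
            if w.get("_task_handle") is not None and not w["_task_handle"].done()
        )
        return active + 1 > max_concurrent

    workers = {i: {"_task_handle": Handle(running=True)} for i in range(3)}
    assert would_exceed(workers, max_concurrent=3)      # 3 active + 1 > 3
    assert not would_exceed(workers, max_concurrent=4)  # 3 active + 1 <= 4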
@@ -4040,58 +4199,9 @@ def register_queen_lifecycle_tools(
# to block the whole spawn with a CredentialError; the fix
# is to treat unset credentials as "drop these tools" rather
# than "abort the worker".
#
# Note: the MCP admission gate (_build_mcp_admission_gate in
# tool_registry.py) already filters MCP tools at registration
# time. This preflight covers the non-MCP path — tools.py
# discoveries via discover_from_module — which has no
# credential gate of its own.
loop = asyncio.get_running_loop()
unavailable_tools: set[str] = set()
async def _preflight():
nonlocal unavailable_tools
try:
from framework.credentials.validation import compute_unavailable_tools
drop, messages = await loop.run_in_executor(
None,
lambda: compute_unavailable_tools(legacy.graph.nodes),
)
unavailable_tools = drop
if drop:
logger.warning(
"run_agent_with_input: dropping %d tool(s) with "
"unavailable credentials from worker spawn: %s",
len(drop),
"; ".join(messages),
)
except Exception as exc:
# Validation itself failing (not a credential failure —
# a code error in the validator) should not block the
# spawn. Log and proceed as if nothing was dropped.
logger.warning(
"compute_unavailable_tools raised, proceeding without credential-based tool filtering: %s",
exc,
)
runner = getattr(session, "runner", None)
if runner:
try:
await loop.run_in_executor(
None,
lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
)
except Exception as e:
logger.warning("MCP resync failed: %s", e)
try:
await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
except TimeoutError:
logger.warning(
"run_agent_with_input preflight timed out after %ds — proceeding",
_START_PREFLIGHT_TIMEOUT,
)
unavailable_tools = await _preflight_credentials(
legacy, tool_label="run_agent_with_input"
)
# Build a per-spawn AgentSpec that mirrors the loaded
# worker's entry-node identity. This is what makes the
+92 -6
@@ -64,6 +64,12 @@ export interface ChatMessage {
nodeId?: string;
/** Backend execution_id for this message */
executionId?: string;
/** Backend stream_id — the per-worker identity used for grouping
* parallel-spawn workers into their own stacked WorkerRunBubble.
* "queen" for queen messages, "worker" for the single loaded
* worker (run_agent_with_input), or "worker:{uuid}" for each
* parallel worker spawned via run_parallel_workers. */
streamId?: string;
/** True when the message was sent while the queen was still processing */
queued?: boolean;
}
@@ -695,9 +701,36 @@ export default function ChatPanel({
type RenderItem =
| { kind: "message"; msg: ChatMessage }
| { kind: "parallel"; groupId: string; groups: SubagentGroup[] }
| { kind: "worker_run"; runId: string; group: WorkerRunGroup }
| {
kind: "worker_run";
runId: string;
group: WorkerRunGroup;
/** Optional short label shown next to the "Worker" badge.
* Only set when there are multiple parallel workers in the
* same run span (so users can tell them apart). */
label?: string;
}
| { kind: "day_divider"; key: string; createdAt: number };
/** Derive a short label from a parallel-worker stream id.
* `worker:abcdef12-3456-...` → `abcdef12` (first 8 chars of the
* uuid after the `worker:` prefix). Falls back to the first
* message's nodeId when the streamId isn't the expected shape. */
function deriveWorkerLabel(
streamKey: string,
msgs: ChatMessage[],
): string {
if (streamKey.startsWith("worker:")) {
const suffix = streamKey.slice("worker:".length);
// sessions are `session_YYYYMMDD_HHMMSS_<8-hex>` — show the
// trailing hex if present, else first 8 chars of the suffix.
const tail = suffix.match(/_[0-9a-f]{6,}$/i)?.[0]?.slice(1);
return tail ? tail.slice(0, 8) : suffix.slice(0, 8);
}
const nid = msgs.find((m) => m.nodeId)?.nodeId;
return nid || streamKey;
}
const renderItems = useMemo<RenderItem[]>(() => {
const items: RenderItem[] = [];
let i = 0;
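The labelling rule translates to a few lines of Python; the regex and fallbacks mirror the TypeScript ``deriveWorkerLabel`` above, while the function and argument names are the sketch's own:

    import re

    def derive_worker_label(stream_key: str, node_ids: list[str]) -> str:
        if stream_key.startswith("worker:"):
            suffix = stream_key[len("worker:"):]
            m = re.search(r"_([0-9a-f]{6,})$", suffix, re.IGNORECASE)
            return m.group(1)[:8] if m else suffix[:8]
        # Not a parallel-worker stream: fall back to the first node id.
        return next((n for n in node_ids if n), stream_key)

    assert derive_worker_label("worker:abcdef12-3456-7890", []) == "abcdef12"
    assert derive_worker_label(
        "worker:session_20260416_122319_deadbeef", []
    ) == "deadbeef"
    assert derive_worker_label("worker", ["coder"]) == "coder"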
@@ -744,11 +777,63 @@ export default function ChatPanel({
}
if (workerMsgs.length > 0) {
items.push({
kind: "worker_run",
runId: `wrun-${firstWorkerMsg.id}`,
group: { messages: workerMsgs },
});
// Parallel fan-out detection: if any message in this span
// is tagged with a parallel-worker streamId (``worker:{uuid}``),
// split the span by streamId and emit one ``worker_run``
// per worker — they render as stacked independent
// ``WorkerRunBubble``s. Un-tagged legacy messages and the
// single-worker ``streamId="worker"`` case fall through to
// the existing single-bubble behavior.
const hasParallel = workerMsgs.some(
(m) => !!m.streamId && /^worker:./.test(m.streamId),
);
if (hasParallel) {
const buckets = new Map<
string,
{ messages: ChatMessage[]; firstAt: number }
>();
// Messages with no streamId (system notes, orphans from
// old restore) attach to the most-recent keyed message's
// bucket so chronology is preserved.
let currentKey: string | null = null;
for (const m of workerMsgs) {
const key =
m.streamId && m.streamId.length > 0
? m.streamId
: currentKey;
if (!key) continue;
if (m.streamId && m.streamId.length > 0) currentKey = m.streamId;
let bucket = buckets.get(key);
if (!bucket) {
bucket = { messages: [], firstAt: m.createdAt ?? 0 };
buckets.set(key, bucket);
}
bucket.messages.push(m);
bucket.firstAt = Math.min(
bucket.firstAt,
m.createdAt ?? Number.POSITIVE_INFINITY,
);
}
const sorted = Array.from(buckets.entries()).sort(
([, a], [, b]) => a.firstAt - b.firstAt,
);
for (const [streamKey, { messages: bucketMsgs }] of sorted) {
items.push({
kind: "worker_run",
runId: `wrun-${firstWorkerMsg.id}-${streamKey}`,
group: { messages: bucketMsgs },
label: deriveWorkerLabel(streamKey, bucketMsgs),
});
}
} else {
items.push({
kind: "worker_run",
runId: `wrun-${firstWorkerMsg.id}`,
group: { messages: workerMsgs },
});
}
}
continue;
}
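The same bucketing pass as a runnable Python sketch: split a run span by stream id, attach untagged messages to the most recently seen bucket, and order buckets by first-seen timestamp. Plain dicts stand in for ``ChatMessage``:

    def split_by_stream(msgs: list[dict]) -> list[tuple[str, list[dict]]]:
        buckets: dict[str, dict] = {}
        current = None
        for m in msgs:
            key = m.get("streamId") or current
            if not key:
                continue  # orphan before any keyed message is dropped
            if m.get("streamId"):
                current = m["streamId"]
            b = buckets.setdefault(
                key, {"messages": [], "firstAt": m.get("createdAt", 0)}
            )
            b["messages"].append(m)
            b["firstAt"] = min(b["firstAt"], m.get("createdAt", float("inf")))
        ordered = sorted(buckets.items(), key=lambda kv: kv[1]["firstAt"])
        return [(k, v["messages"]) for k, v in ordered]

    msgs = [
        {"streamId": "worker:a", "createdAt": 1},
        {"streamId": "worker:b", "createdAt": 2},
        {"createdAt": 3},  # untagged: joins worker:b, the current bucket
        {"streamId": "worker:a", "createdAt": 4},
    ]
    out = split_by_stream(msgs)
    assert [k for k, _ in out] == ["worker:a", "worker:b"]
    assert len(dict(out)["worker:b"]) == 2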
@@ -958,6 +1043,7 @@ export default function ChatPanel({
<WorkerRunBubble
runId={item.runId}
group={item.group}
label={item.label}
/>
</div>
);
@@ -13,6 +13,11 @@ export interface WorkerRunGroup {
interface WorkerRunBubbleProps {
runId: string;
group: WorkerRunGroup;
/** Short identifier shown next to the "Worker" badge. Populated
* only when the parent grouping has multiple parallel workers
* in the same run span, so N stacked bubbles can be told apart
* at a glance. Omitted for single-worker runs. */
label?: string;
}
/** Parse a tool_status JSON blob into a list of tool entries. */
@@ -60,7 +65,7 @@ function stripMarkdownToPreview(s: string, maxLen = 160): string {
* Expanded: scrollable list of every message and tool status in order.
*/
const WorkerRunBubble = memo(
function WorkerRunBubble({ group }: WorkerRunBubbleProps) {
function WorkerRunBubble({ group, label }: WorkerRunBubbleProps) {
const [expanded, setExpanded] = useState(false);
const bodyRef = useRef<HTMLDivElement>(null);
@@ -138,6 +143,11 @@ const WorkerRunBubble = memo(
<span className="font-medium text-xs" style={{ color: workerColor }}>
Worker
</span>
{label && (
<span className="text-[10px] font-mono text-muted-foreground/80 tabular-nums">
{label}
</span>
)}
<span
className={`text-[10px] font-medium px-1.5 py-0.5 rounded-md ${
isFinished
@@ -278,6 +288,7 @@ const WorkerRunBubble = memo(
},
(prev, next) =>
prev.runId === next.runId &&
prev.label === next.label &&
prev.group.messages.length === next.group.messages.length &&
prev.group.messages[prev.group.messages.length - 1]?.content ===
next.group.messages[next.group.messages.length - 1]?.content
+7
@@ -119,6 +119,7 @@ export function sseEventToChatMessage(
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: event.stream_id || undefined,
};
}
@@ -138,6 +139,7 @@ export function sseEventToChatMessage(
type: "user",
thread,
createdAt,
streamId: event.stream_id || undefined,
};
}
@@ -158,6 +160,7 @@ export function sseEventToChatMessage(
createdAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: event.stream_id || undefined,
};
}
@@ -172,6 +175,7 @@ export function sseEventToChatMessage(
type: "system",
thread,
createdAt,
streamId: event.stream_id || undefined,
};
}
@@ -186,6 +190,7 @@ export function sseEventToChatMessage(
type: "system",
thread,
createdAt,
streamId: event.stream_id || undefined,
};
}
@@ -301,6 +306,7 @@ export function replayEvent(
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: streamId || undefined,
});
break;
}
@@ -331,6 +337,7 @@ export function replayEvent(
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: streamId || undefined,
});
break;
}
+1
@@ -855,6 +855,7 @@ export default function ColonyChat() {
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: sid || undefined,
});
return { ...prev, isStreaming: false, activeToolCalls: newActive };
});
+1
@@ -621,6 +621,7 @@ export default function QueenDM() {
createdAt: eventCreatedAt,
nodeId: event.node_id || undefined,
executionId: event.execution_id || undefined,
streamId: sid || undefined,
};
setMessages((prevMsgs) => {
const idx = prevMsgs.findIndex((m) => m.id === msgId);
+1 -1
@@ -66,7 +66,7 @@ class TestDefaultSkillManager:
manager = DefaultSkillManager()
manager.load()
assert len(manager.active_skill_names) == 7
assert len(manager.active_skill_names) == len(SKILL_REGISTRY)
for name in SKILL_REGISTRY:
assert name in manager.active_skill_names
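The pattern generalizes: derive the expected value from the registry itself rather than hardcoding its current size, so registering an eighth skill cannot break the test. A minimal stand-alone sketch, with stand-ins for the real module:

    SKILL_REGISTRY = {"plan": object(), "code": object(), "review": object()}

    class Manager:
        def load(self) -> None:
            self.active_skill_names = set(SKILL_REGISTRY)

    m = Manager()
    m.load()
    assert len(m.active_skill_names) == len(SKILL_REGISTRY)
    assert all(name in m.active_skill_names for name in SKILL_REGISTRY)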