fix: parallel execution
This commit is contained in:
@@ -1100,6 +1100,89 @@ class ColonyRuntime:
|
||||
return True
|
||||
return False
|
||||
|
||||
def watch_batch_timeouts(
|
||||
self,
|
||||
worker_ids: list[str],
|
||||
*,
|
||||
soft_timeout: float,
|
||||
hard_timeout: float,
|
||||
warning_message: str | None = None,
|
||||
) -> asyncio.Task:
|
||||
"""Schedule a background task that enforces soft + hard timeouts.
|
||||
|
||||
Semantics:
|
||||
* At ``t = soft_timeout`` every worker in ``worker_ids`` that is
|
||||
still active AND hasn't already filed an ``_explicit_report``
|
||||
receives ``warning_message`` via ``send_to_worker`` — the inject
|
||||
appears as a user turn at the next agent-loop boundary, so the
|
||||
worker's LLM can see it and call ``report_to_parent`` with
|
||||
partial results.
|
||||
* At ``t = hard_timeout`` any worker still active is force-stopped
|
||||
via ``stop_worker``. ``Worker.run`` still emits its
|
||||
``SUBAGENT_REPORT`` on cancel (the explicit report survives,
|
||||
if the worker reported just before the stop) so the queen
|
||||
always sees a terminal inject for every spawned worker.
|
||||
|
||||
Returns the scheduled task so callers can await or cancel it.
|
||||
Non-blocking for the caller — the watcher runs on the event loop
|
||||
independently.
|
||||
"""
|
||||
if warning_message is None:
|
||||
grace = max(0.0, hard_timeout - soft_timeout)
|
||||
warning_message = (
|
||||
f"[SOFT TIMEOUT] You've been running for {soft_timeout:.0f}s. "
|
||||
"Wrap up now: call report_to_parent with whatever partial "
|
||||
"results you have. You have "
|
||||
f"~{grace:.0f}s more before a hard stop — anything not "
|
||||
"reported by then will be lost."
|
||||
)
|
||||
|
||||
async def _watch() -> None:
|
||||
try:
|
||||
await asyncio.sleep(soft_timeout)
|
||||
for wid in worker_ids:
|
||||
worker = self._workers.get(wid)
|
||||
if worker is None or not worker.is_active:
|
||||
continue
|
||||
if getattr(worker, "_explicit_report", None) is not None:
|
||||
continue
|
||||
try:
|
||||
await self.send_to_worker(wid, warning_message)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"watch_batch_timeouts: soft-timeout inject failed for %s",
|
||||
wid,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
remaining = hard_timeout - soft_timeout
|
||||
if remaining <= 0:
|
||||
return
|
||||
await asyncio.sleep(remaining)
|
||||
for wid in worker_ids:
|
||||
worker = self._workers.get(wid)
|
||||
if worker is None or not worker.is_active:
|
||||
continue
|
||||
try:
|
||||
await self.stop_worker(wid)
|
||||
logger.info(
|
||||
"watch_batch_timeouts: hard-stopped %s after %ss (no report)",
|
||||
wid,
|
||||
hard_timeout,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"watch_batch_timeouts: hard-stop failed for %s",
|
||||
wid,
|
||||
exc_info=True,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("watch_batch_timeouts: watcher crashed")
|
||||
|
||||
return asyncio.create_task(_watch(), name=f"batch-timeout:{worker_ids[0] if worker_ids else '?'}")
|
||||
|
||||
# ── Status & Query ──────────────────────────────────────────
|
||||
|
||||
def list_workers(self) -> list[WorkerInfo]:
|
||||
|
||||
@@ -170,13 +170,28 @@ class Worker:
|
||||
except asyncio.CancelledError:
|
||||
self.status = WorkerStatus.STOPPED
|
||||
duration = time.monotonic() - self._started_at
|
||||
self._result = WorkerResult(
|
||||
error="Worker stopped by queen",
|
||||
duration_seconds=duration,
|
||||
status="stopped",
|
||||
summary="Worker was cancelled before completion.",
|
||||
)
|
||||
await self._emit_terminal_events(None, force_status="stopped")
|
||||
# Preserve any explicit report the worker's LLM already filed
|
||||
# via ``report_to_parent`` before being cancelled — the caller
|
||||
# cares about that payload even on a hard stop. Only fall back
|
||||
# to the canned "stopped" message when no explicit report exists.
|
||||
explicit = self._explicit_report
|
||||
if explicit is not None:
|
||||
self._result = WorkerResult(
|
||||
error="Worker stopped by queen after reporting",
|
||||
duration_seconds=duration,
|
||||
status=explicit["status"],
|
||||
summary=explicit["summary"],
|
||||
data=explicit["data"],
|
||||
)
|
||||
await self._emit_terminal_events(None, force_status=explicit["status"])
|
||||
else:
|
||||
self._result = WorkerResult(
|
||||
error="Worker stopped by queen",
|
||||
duration_seconds=duration,
|
||||
status="stopped",
|
||||
summary="Worker was cancelled before completion.",
|
||||
)
|
||||
await self._emit_terminal_events(None, force_status="stopped")
|
||||
return self._result
|
||||
|
||||
except Exception as exc:
|
||||
|
||||
@@ -722,47 +722,71 @@ async def create_queen(
|
||||
|
||||
phase_state.inject_notification = _inject_phase_notification
|
||||
|
||||
async def _on_worker_done(event):
|
||||
async def _on_worker_report(event):
|
||||
"""Inject [WORKER_REPORT] into queen as each worker finishes.
|
||||
|
||||
Subscribes to SUBAGENT_REPORT events which carry the worker's
|
||||
real summary/data (preferring any explicit ``report_to_parent``
|
||||
call). Every spawned worker emits exactly one — success,
|
||||
partial, failed, timeout, or stopped. The queen sees the
|
||||
report as the next user turn and can react (reply to user,
|
||||
kick off follow-up work, etc.) without being blocked by the
|
||||
spawn call itself.
|
||||
"""
|
||||
if event.stream_id == "queen":
|
||||
return
|
||||
# "working" is the 3-phase target; "running" is the
|
||||
# legacy 6-phase equivalent — accept both until the
|
||||
# legacy lifecycle is deleted in Commit 2.
|
||||
if phase_state.phase in ("working", "running"):
|
||||
if event.type == EventType.EXECUTION_COMPLETED:
|
||||
session.worker_configured = True
|
||||
output = event.data.get("output", {})
|
||||
output_summary = ""
|
||||
if output:
|
||||
for key, value in output.items():
|
||||
val_str = str(value)
|
||||
if len(val_str) > 200:
|
||||
val_str = val_str[:200] + "..."
|
||||
output_summary += f"\n {key}: {val_str}"
|
||||
_out = output_summary or " (no output keys set)"
|
||||
notification = (
|
||||
"[WORKER_TERMINAL] Worker finished successfully.\n"
|
||||
f"Output:{_out}\n"
|
||||
"Report this to the user. "
|
||||
"Ask if they want to re-run with different input "
|
||||
"or tweak the configuration."
|
||||
)
|
||||
else:
|
||||
error = event.data.get("error", "Unknown error")
|
||||
notification = (
|
||||
"[WORKER_TERMINAL] Worker failed.\n"
|
||||
f"Error: {error}\n"
|
||||
"Report this to the user and help them troubleshoot. "
|
||||
"You can re-run with different input or escalate to "
|
||||
"building/planning if code changes are needed."
|
||||
)
|
||||
data = event.data or {}
|
||||
worker_id = data.get("worker_id", event.node_id or "unknown")
|
||||
status = data.get("status", "unknown")
|
||||
summary = data.get("summary") or "(no summary)"
|
||||
err = data.get("error")
|
||||
payload_data = data.get("data") or {}
|
||||
duration = data.get("duration_seconds")
|
||||
|
||||
await agent_loop.inject_event(notification)
|
||||
lines = ["[WORKER_REPORT]", f"worker_id: {worker_id}", f"status: {status}"]
|
||||
if duration is not None:
|
||||
try:
|
||||
lines.append(f"duration: {float(duration):.1f}s")
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
lines.append(f"summary: {summary}")
|
||||
if err:
|
||||
lines.append(f"error: {err}")
|
||||
if payload_data:
|
||||
# Compact JSON so the queen sees all keys without the
|
||||
# indentation blowing up the turn's token count.
|
||||
try:
|
||||
import json as _json
|
||||
|
||||
lines.append("data: " + _json.dumps(payload_data, ensure_ascii=False, default=str))
|
||||
except Exception:
|
||||
lines.append(f"data: {payload_data!r}")
|
||||
notification = "\n".join(lines)
|
||||
|
||||
await agent_loop.inject_event(notification)
|
||||
session.worker_configured = True
|
||||
|
||||
# Only transition to reviewing once the batch has quieted —
|
||||
# if other workers from a parallel spawn are still live, stay
|
||||
# in working so the queen's tool access (run_parallel_workers,
|
||||
# inject_message, stop_worker) remains available.
|
||||
colony_runtime = getattr(session, "colony_runtime", None)
|
||||
still_active = 0
|
||||
if colony_runtime is not None:
|
||||
try:
|
||||
still_active = sum(
|
||||
1
|
||||
for w in colony_runtime._workers.values() # type: ignore[attr-defined]
|
||||
if getattr(w, "is_active", False)
|
||||
)
|
||||
except Exception:
|
||||
still_active = 0
|
||||
if still_active == 0 and phase_state.phase in ("working", "running"):
|
||||
await phase_state.switch_to_reviewing(source="auto")
|
||||
|
||||
session.event_bus.subscribe(
|
||||
event_types=[EventType.EXECUTION_COMPLETED, EventType.EXECUTION_FAILED],
|
||||
handler=_on_worker_done,
|
||||
event_types=[EventType.SUBAGENT_REPORT],
|
||||
handler=_on_worker_report,
|
||||
)
|
||||
|
||||
# ---- Colony-scoped worker escalation routing ----
|
||||
|
||||
@@ -53,9 +53,21 @@ _WORKER_INHERITED_TOOLS: frozenset[str] = frozenset(
|
||||
|
||||
# Queen-lifecycle tools that are registered into the queen's tool registry
|
||||
# but NOT listed in any _QUEEN_*_TOOLS phase list (they're reachable only via
|
||||
# explicit registration, not phase-based gating). These must still be stripped
|
||||
# from forked worker configs.
|
||||
_QUEEN_LIFECYCLE_EXTRAS: frozenset[str] = frozenset()
|
||||
# explicit registration or as frontend-visible helpers, not phase-based
|
||||
# gating). These must still be stripped from forked / parallel-spawned
|
||||
# worker tool inventories.
|
||||
_QUEEN_LIFECYCLE_EXTRAS: frozenset[str] = frozenset(
|
||||
{
|
||||
# Phase-transition wrappers (method variants are on QueenPhaseState
|
||||
# but the queen also sees them as tools).
|
||||
"switch_to_reviewing",
|
||||
"switch_to_independent",
|
||||
# Frontend helpers that live outside phase lists.
|
||||
"list_credentials",
|
||||
"get_worker_health_summary",
|
||||
"enqueue_task",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _resolve_queen_only_tools() -> frozenset[str]:
|
||||
|
||||
@@ -728,22 +728,33 @@ def register_queen_lifecycle_tools(
|
||||
|
||||
# --- run_parallel_workers --------------------------------------------------
|
||||
#
|
||||
# Phase 4 fan-out tool. Reads the unified ColonyRuntime from
|
||||
# ``session.colony`` (built by SessionManager._start_unified_colony_runtime),
|
||||
# spawns one Worker per task spec via spawn_batch, then blocks on
|
||||
# wait_for_worker_reports until every worker has reported (or the
|
||||
# timeout fires and stragglers are force-stopped). Returns a JSON
|
||||
# array of structured reports {worker_id, status, summary, data,
|
||||
# error, duration_seconds, tokens_used} that the queen reads on its
|
||||
# next turn and aggregates into a user-facing summary.
|
||||
# Fire-and-forget fan-out tool. Spawns one Worker per task spec via
|
||||
# ``colony.spawn_batch`` and returns IMMEDIATELY with the worker ids
|
||||
# and schedule info. The tool no longer blocks on
|
||||
# ``wait_for_worker_reports`` — workers run in the background and
|
||||
# each emits a ``SUBAGENT_REPORT`` event when it terminates.
|
||||
# ``queen_orchestrator._on_worker_report`` subscribes to that event
|
||||
# and injects a ``[WORKER_REPORT]`` user turn into the queen's
|
||||
# conversation, so the queen sees each result as a normal inbound
|
||||
# message and can react without being blocked by the spawn call.
|
||||
#
|
||||
# Worker SUBAGENT_REPORT events flow through session.event_bus, so
|
||||
# the existing SSE pipeline surfaces them automatically. Workers'
|
||||
# individual LLM deltas / tool calls also publish to the same bus
|
||||
# under stream_id="worker:{worker_id}"; SSE filtering for those is
|
||||
# Phase 5 — for now they reach the queen DM channel.
|
||||
# Soft + hard timeouts are enforced by
|
||||
# ``ColonyRuntime.watch_batch_timeouts``: at soft-timeout, every
|
||||
# still-active worker that hasn't already filed an explicit report
|
||||
# receives a SOFT TIMEOUT inject telling it to call report_to_parent
|
||||
# now; at hard-timeout, any remaining worker is force-stopped
|
||||
# (and its SUBAGENT_REPORT still fires — explicit reports set right
|
||||
# before the stop are preserved).
|
||||
|
||||
_RUN_PARALLEL_DEFAULT_TIMEOUT = 600.0 # 10 minutes per batch
|
||||
_RUN_PARALLEL_DEFAULT_TIMEOUT = 600.0 # soft timeout (10 min)
|
||||
_RUN_PARALLEL_HARD_TIMEOUT_CAP = 3600.0 # absolute safety-net cap (1 hour)
|
||||
|
||||
def _compute_hard_timeout(soft: float) -> float:
|
||||
"""Default hard cutoff: max(4× soft, soft + 600), capped at 3600s."""
|
||||
return min(
|
||||
_RUN_PARALLEL_HARD_TIMEOUT_CAP,
|
||||
max(soft * 4.0, soft + 600.0),
|
||||
)
|
||||
|
||||
def _get_unified_colony():
|
||||
"""Read the unified ColonyRuntime (Phase 2 wiring) from session."""
|
||||
@@ -753,11 +764,21 @@ def register_queen_lifecycle_tools(
|
||||
*,
|
||||
tasks: list[dict],
|
||||
timeout: float | None = None,
|
||||
hard_timeout: float | None = None,
|
||||
) -> str:
|
||||
"""Spawn N parallel workers and wait for all reports.
|
||||
"""Spawn N parallel workers and return immediately.
|
||||
|
||||
Each task is a dict ``{"task": str, "data": dict | None}``.
|
||||
Returns a JSON array of structured reports in input order.
|
||||
Workers run in the background; each one emits a ``SUBAGENT_REPORT``
|
||||
when it finishes, which the queen sees as a ``[WORKER_REPORT]``
|
||||
user turn. The queen stays unblocked for other work.
|
||||
|
||||
``timeout`` is a **soft** deadline (default 600s). When it
|
||||
expires, each still-active worker without an explicit report
|
||||
gets a SOFT TIMEOUT inject telling it to call ``report_to_parent``
|
||||
now. Workers ignoring the warning are force-stopped at the
|
||||
``hard_timeout`` (default: derived from ``timeout``, capped at
|
||||
3600s).
|
||||
"""
|
||||
colony = _get_unified_colony()
|
||||
if colony is None:
|
||||
@@ -847,17 +868,26 @@ def register_queen_lifecycle_tools(
|
||||
)
|
||||
return json.dumps(error_payload)
|
||||
|
||||
if unavailable_tools_parallel:
|
||||
colony_tools = list(getattr(colony, "_tools", []) or [])
|
||||
before = len(colony_tools)
|
||||
tools_override_parallel = [
|
||||
t
|
||||
for t in colony_tools
|
||||
if getattr(t, "name", None) not in unavailable_tools_parallel
|
||||
]
|
||||
# Always filter queen-lifecycle tools + any tools with missing
|
||||
# credentials. Without the queen-only strip the spawned worker
|
||||
# inherits run_parallel_workers / create_colony / switch_to_*,
|
||||
# which lets it recurse or flip the parent queen's phase.
|
||||
from framework.server.routes_execution import _resolve_queen_only_tools
|
||||
|
||||
queen_only = _resolve_queen_only_tools()
|
||||
colony_tools = list(getattr(colony, "_tools", []) or [])
|
||||
before = len(colony_tools)
|
||||
tools_override_parallel = [
|
||||
t
|
||||
for t in colony_tools
|
||||
if getattr(t, "name", None) not in queen_only
|
||||
and getattr(t, "name", None) not in unavailable_tools_parallel
|
||||
]
|
||||
dropped = before - len(tools_override_parallel)
|
||||
if dropped:
|
||||
logger.info(
|
||||
"run_parallel_workers: dropped %d tool object(s) from spawn_tools (unavailable credentials)",
|
||||
before - len(tools_override_parallel),
|
||||
"run_parallel_workers: stripped %d queen/unavailable tool(s) from spawn_tools",
|
||||
dropped,
|
||||
)
|
||||
|
||||
# Colony progress tracker wiring: if the session's loaded
|
||||
@@ -956,11 +986,9 @@ def register_queen_lifecycle_tools(
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"spawn_batch failed: {e}"})
|
||||
|
||||
# Phase transition — with the batch now spawned, the queen is
|
||||
# semantically "working" until wait_for_worker_reports returns,
|
||||
# so phase-gated working tools (inject_message, reply_to_worker,
|
||||
# ...) should be available. Worker-finish auto-transitions the
|
||||
# queen to "reviewing" (see queen_orchestrator._on_worker_done).
|
||||
# Phase transition — workers are now live, queen is in "working"
|
||||
# phase. Worker-finish auto-transitions back to "reviewing" once
|
||||
# every worker has reported (see queen_orchestrator._on_worker_report).
|
||||
if phase_state is not None:
|
||||
try:
|
||||
await phase_state.switch_to_working()
|
||||
@@ -973,33 +1001,50 @@ def register_queen_lifecycle_tools(
|
||||
exc,
|
||||
)
|
||||
|
||||
# Soft + hard timeout watcher runs in the background. At soft,
|
||||
# it injects a "wrap up" message to every still-active worker
|
||||
# without an explicit report; at hard, it force-stops the stragglers.
|
||||
soft_timeout = timeout if timeout is not None else _RUN_PARALLEL_DEFAULT_TIMEOUT
|
||||
hard_timeout_effective = (
|
||||
hard_timeout if hard_timeout is not None else _compute_hard_timeout(soft_timeout)
|
||||
)
|
||||
if hard_timeout_effective <= soft_timeout:
|
||||
hard_timeout_effective = soft_timeout + 60.0 # enforce at least a 60s grace
|
||||
try:
|
||||
reports = await colony.wait_for_worker_reports(
|
||||
colony.watch_batch_timeouts(
|
||||
worker_ids,
|
||||
timeout=timeout if timeout is not None else _RUN_PARALLEL_DEFAULT_TIMEOUT,
|
||||
soft_timeout=soft_timeout,
|
||||
hard_timeout=hard_timeout_effective,
|
||||
)
|
||||
except Exception as e:
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"wait_for_worker_reports failed: {e}",
|
||||
"worker_ids": worker_ids,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"run_parallel_workers: failed to schedule timeout watcher (non-fatal): %s",
|
||||
exc,
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"worker_count": len(reports),
|
||||
"reports": reports,
|
||||
"status": "started",
|
||||
"worker_count": len(worker_ids),
|
||||
"worker_ids": worker_ids,
|
||||
"soft_timeout_seconds": soft_timeout,
|
||||
"hard_timeout_seconds": hard_timeout_effective,
|
||||
"message": (
|
||||
"Workers running in the background. Each will report via "
|
||||
"[WORKER_REPORT] as it finishes. Reply to the user naturally "
|
||||
"in the meantime; you do not need to poll."
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
_run_parallel_tool = Tool(
|
||||
name="run_parallel_workers",
|
||||
description=(
|
||||
"Fan out a batch of tasks to parallel workers and wait for all "
|
||||
"reports. Use this when you can split the work into independent "
|
||||
"subtasks that can run concurrently (e.g. fetching N batches "
|
||||
"from an API, processing M files, comparing K candidates).\n\n"
|
||||
"Fan out a batch of tasks to parallel workers and RETURN "
|
||||
"IMMEDIATELY. Workers run in the background; each one reports "
|
||||
"back to you as a [WORKER_REPORT] user turn when it finishes, "
|
||||
"so you stay unblocked and can chat with the user, kick off "
|
||||
"more work, or do anything else in the meantime.\n\n"
|
||||
"CRITICAL: each worker is a FRESH process with NO memory of "
|
||||
"your conversation. Every task string must be FULLY "
|
||||
"self-contained — include the API endpoint, the exact "
|
||||
@@ -1008,14 +1053,20 @@ def register_queen_lifecycle_tools(
|
||||
"questions and cannot see your chat history. Write each "
|
||||
"task as if handing it to a stranger.\n\n"
|
||||
"Each worker runs in isolation with its own AgentLoop and "
|
||||
"reports back via the report_to_parent tool. The call "
|
||||
"blocks until every worker has reported or the timeout "
|
||||
"fires. Returns a JSON object with a 'reports' array; each "
|
||||
"report has worker_id, status "
|
||||
"(success|partial|failed|timeout|stopped), summary, data, "
|
||||
"error, duration_seconds, and tokens_used. Read the "
|
||||
"summaries on your next turn and synthesize a user-facing "
|
||||
"result. Default timeout is 600 seconds (10 minutes)."
|
||||
"reports back via the report_to_parent tool. The tool "
|
||||
"returns a JSON object with status='started' and the list "
|
||||
"of worker_ids you just spawned. Each worker's completion "
|
||||
"arrives later as a [WORKER_REPORT] message containing "
|
||||
"worker_id, status (success|partial|failed|timeout|stopped), "
|
||||
"summary, data, error, duration. Read those messages as "
|
||||
"they arrive and respond to the user naturally.\n\n"
|
||||
"TIMEOUT — 'timeout' is a SOFT deadline (default 600s). "
|
||||
"When it expires, every still-active worker that hasn't "
|
||||
"reported gets a [SOFT TIMEOUT] message telling it to "
|
||||
"call report_to_parent now. It has until 'hard_timeout' "
|
||||
"(default derived from timeout, capped at 3600s) to "
|
||||
"wrap up before being force-stopped. Explicit reports "
|
||||
"filed during the warning window ARE preserved."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
@@ -1049,9 +1100,17 @@ def register_queen_lifecycle_tools(
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": (
|
||||
"Per-batch timeout in seconds. Workers still "
|
||||
"running when the timeout fires are force-stopped "
|
||||
"and reported as status='timeout'. Default 600."
|
||||
"SOFT deadline in seconds. Workers still running "
|
||||
"at this point are messaged to call report_to_parent. "
|
||||
"Default 600 (10 minutes)."
|
||||
),
|
||||
},
|
||||
"hard_timeout": {
|
||||
"type": "number",
|
||||
"description": (
|
||||
"Absolute cutoff in seconds. Workers still active "
|
||||
"at this point are force-stopped. Defaults to "
|
||||
"max(timeout × 4, timeout + 600), capped at 3600s."
|
||||
),
|
||||
},
|
||||
},
|
||||
|
||||
@@ -304,16 +304,22 @@ export default function ColonyChat() {
|
||||
);
|
||||
}
|
||||
if (options?.reconcileOptimisticUser && chatMsg.type === "user" && prev.length > 0) {
|
||||
const lastIdx = prev.length - 1;
|
||||
const lastMsg = prev[lastIdx];
|
||||
// Match by content + timestamp across the whole list (not just
|
||||
// the last slot) so a queued user message still reconciles
|
||||
// even when the queen's previous reply slotted in between.
|
||||
// Also drops the "queued" indicator since the backend has
|
||||
// now confirmed receipt.
|
||||
const incomingTs = chatMsg.createdAt ?? Date.now();
|
||||
const lastTs = lastMsg.createdAt ?? incomingTs;
|
||||
if (
|
||||
lastMsg.type === "user" &&
|
||||
lastMsg.content === chatMsg.content &&
|
||||
Math.abs(incomingTs - lastTs) <= 15000
|
||||
) {
|
||||
return prev.map((m, i) => (i === lastIdx ? { ...m, id: chatMsg.id } : m));
|
||||
const matchIdx = prev.findIndex(
|
||||
(m) =>
|
||||
m.type === "user" &&
|
||||
m.content === chatMsg.content &&
|
||||
Math.abs(incomingTs - (m.createdAt ?? incomingTs)) <= 15000,
|
||||
);
|
||||
if (matchIdx !== -1) {
|
||||
return prev.map((m, i) =>
|
||||
i === matchIdx ? { ...m, id: chatMsg.id, queued: undefined } : m,
|
||||
);
|
||||
}
|
||||
}
|
||||
// Insert in sorted position by createdAt so tool pills and queen
|
||||
|
||||
@@ -515,17 +515,21 @@ export default function QueenDM() {
|
||||
);
|
||||
if (chatMsg) {
|
||||
setMessages((prev) => {
|
||||
// Reconcile optimistic user message
|
||||
// Reconcile optimistic user message. A matching echo from
|
||||
// the backend means the queen has now received this
|
||||
// message, so drop the "queued" indicator (it was set when
|
||||
// the user sent while the queen was still busy).
|
||||
if (chatMsg.type === "user" && prev.length > 0) {
|
||||
const last = prev[prev.length - 1];
|
||||
if (
|
||||
last.type === "user" &&
|
||||
last.content === chatMsg.content &&
|
||||
Math.abs((chatMsg.createdAt ?? 0) - (last.createdAt ?? 0)) <=
|
||||
15000
|
||||
) {
|
||||
const idx = prev.findIndex(
|
||||
(m) =>
|
||||
m.type === "user" &&
|
||||
m.content === chatMsg.content &&
|
||||
Math.abs((chatMsg.createdAt ?? 0) - (m.createdAt ?? 0)) <=
|
||||
15000,
|
||||
);
|
||||
if (idx !== -1) {
|
||||
return prev.map((m, i) =>
|
||||
i === prev.length - 1 ? { ...m, id: chatMsg.id } : m,
|
||||
i === idx ? { ...m, id: chatMsg.id, queued: undefined } : m,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Phase 4 test: run_parallel_workers tool fans out through session.colony.
|
||||
"""Coverage of the run_parallel_workers tool (fire-and-forget contract).
|
||||
|
||||
End-to-end coverage of the queen-side parallel-worker tool:
|
||||
The tool spawns workers and returns immediately with worker_ids. Each
|
||||
worker's completion arrives on the event bus as SUBAGENT_REPORT, which
|
||||
the queen orchestrator's _on_worker_report bridge turns into a
|
||||
[WORKER_REPORT] user inject. These tests verify:
|
||||
|
||||
1. Build a real ``ColonyRuntime`` (the Phase 1 + 2 unified runtime).
|
||||
2. Stand up the queen lifecycle tools registered against a fake session
|
||||
that exposes ``session.colony``.
|
||||
3. Invoke the ``run_parallel_workers`` tool with three task specs whose
|
||||
workers each call ``report_to_parent`` with structured payloads.
|
||||
4. Assert that the tool returns aggregated reports in the same order as
|
||||
the input tasks and that all workers ran in parallel under
|
||||
``{storage}/workers/{worker_id}/``.
|
||||
1. The tool returns immediately with status="started" and the list of
|
||||
worker_ids, not with aggregated reports.
|
||||
2. SUBAGENT_REPORT events are emitted for every spawned worker with
|
||||
the expected payload (status, summary, data).
|
||||
3. Soft-timeout inject reaches still-active workers that haven't
|
||||
filed an explicit report; workers that finished early are not
|
||||
disturbed.
|
||||
4. Hard cutoff force-stops workers that ignored the warning, but
|
||||
preserves any explicit report filed right before the stop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,7 +28,7 @@ import pytest
|
||||
|
||||
from framework.agent_loop.types import AgentSpec
|
||||
from framework.host.colony_runtime import ColonyRuntime
|
||||
from framework.host.event_bus import EventBus
|
||||
from framework.host.event_bus import AgentEvent, EventBus, EventType
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
@@ -112,9 +116,11 @@ class _FakeSession:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_parallel_workers_tool_fans_out_and_aggregates(
|
||||
async def test_run_parallel_workers_tool_returns_immediately_and_emits_reports(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Contract: tool returns status='started' right away; SUBAGENT_REPORT
|
||||
events for every spawned worker arrive asynchronously on the bus."""
|
||||
bus = EventBus()
|
||||
llm = _ByTaskMockLLM(
|
||||
by_task={
|
||||
@@ -128,7 +134,7 @@ async def test_run_parallel_workers_tool_fans_out_and_aggregates(
|
||||
agent_spec=AgentSpec(
|
||||
id="test_colony",
|
||||
name="Test Colony",
|
||||
description="Phase 4 test colony.",
|
||||
description="async-spawn test colony.",
|
||||
system_prompt="You are a test agent.",
|
||||
agent_type="event_loop",
|
||||
output_keys=[],
|
||||
@@ -140,21 +146,27 @@ async def test_run_parallel_workers_tool_fans_out_and_aggregates(
|
||||
tools=[],
|
||||
tool_executor=_stub_executor,
|
||||
event_bus=bus,
|
||||
colony_id="phase4_test",
|
||||
colony_id="async_test",
|
||||
pipeline_stages=[],
|
||||
)
|
||||
await colony.start()
|
||||
|
||||
session = _FakeSession(colony, "phase4_test")
|
||||
# Collect SUBAGENT_REPORT events as they arrive.
|
||||
collected_reports: list[dict] = []
|
||||
|
||||
async def _on_report(event: AgentEvent) -> None:
|
||||
collected_reports.append(event.data or {})
|
||||
|
||||
bus.subscribe(event_types=[EventType.SUBAGENT_REPORT], handler=_on_report)
|
||||
|
||||
session = _FakeSession(colony, "async_test")
|
||||
registry = ToolRegistry()
|
||||
register_queen_lifecycle_tools(registry, session=session, session_id=session.id)
|
||||
|
||||
try:
|
||||
# Tool exists in the registry
|
||||
tools = registry.get_tools()
|
||||
assert "run_parallel_workers" in tools
|
||||
|
||||
# Invoke it via the registered executor
|
||||
executor = registry.get_executor()
|
||||
tool_use = ToolUse(
|
||||
id="tu_run_parallel",
|
||||
@@ -165,23 +177,41 @@ async def test_run_parallel_workers_tool_fans_out_and_aggregates(
|
||||
{"task": "fetch-B"},
|
||||
{"task": "fetch-C"},
|
||||
],
|
||||
"timeout": 10.0,
|
||||
"timeout": 30.0,
|
||||
},
|
||||
)
|
||||
result = executor(tool_use)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
# The tool must return quickly — well before workers finish.
|
||||
async def _invoke() -> Any:
|
||||
r = executor(tool_use)
|
||||
if asyncio.iscoroutine(r):
|
||||
r = await r
|
||||
return r
|
||||
|
||||
result = await asyncio.wait_for(_invoke(), timeout=5.0)
|
||||
|
||||
assert not result.is_error, f"Tool errored: {result.content}"
|
||||
payload = json.loads(result.content)
|
||||
assert payload["status"] == "started"
|
||||
assert payload["worker_count"] == 3
|
||||
reports = payload["reports"]
|
||||
assert len(reports) == 3
|
||||
assert len(payload["worker_ids"]) == 3
|
||||
assert payload["soft_timeout_seconds"] == 30.0
|
||||
assert payload["hard_timeout_seconds"] >= 30.0 + 60.0 # at least 60s grace
|
||||
assert "[WORKER_REPORT]" in payload["message"]
|
||||
assert "reports" not in payload # fire-and-forget — no aggregated reports
|
||||
|
||||
# Reports come back in the same order as the input tasks
|
||||
statuses = [r["status"] for r in reports]
|
||||
summaries = [r["summary"] for r in reports]
|
||||
assert statuses == ["success", "success", "failed"]
|
||||
# Now wait for workers to finish and SUBAGENT_REPORT to fire.
|
||||
for _ in range(40):
|
||||
if len(collected_reports) >= 3:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(collected_reports) == 3, (
|
||||
f"Expected 3 SUBAGENT_REPORT events, got {len(collected_reports)}"
|
||||
)
|
||||
statuses = sorted(r["status"] for r in collected_reports)
|
||||
summaries = sorted(r["summary"] for r in collected_reports)
|
||||
assert statuses == ["failed", "success", "success"]
|
||||
assert summaries == ["A done", "B done", "C broke"]
|
||||
|
||||
# Each worker landed under {storage}/workers/{worker_id}/
|
||||
@@ -268,3 +298,175 @@ async def test_run_parallel_workers_validates_tasks_input() -> None:
|
||||
assert "error" in await _call({"tasks": [{"data": {}}]})
|
||||
finally:
|
||||
await colony.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Soft-timeout inject reaches slow workers; explicit-report preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _SlowLLM(LLMProvider):
    """Mock LLM that stalls on _await_user_input by never yielding a finish.

    Each call to ``stream`` awaits the ``stall_event`` before emitting any
    tokens — tests drive it via ``release()``. When the worker's LLM is
    stuck waiting, the watcher's inject message arrives at ``_input_queue``
    but the LLM turn doesn't see it until the current stream finishes.
    We simulate "worker is stuck mid-turn" by holding the stall until the
    test explicitly releases it.
    """

    model: str = "mock-slow"

    def __init__(self) -> None:
        # Gate the test opens to let a stalled stream() turn proceed.
        self.stall_event = asyncio.Event()
        # NOTE(review): not consulted anywhere in this class's visible code —
        # presumably read by the test driving the mock; confirm before removing.
        self.release_after_inject: bool = False
        # Optional (status, summary, data) payload; when set, the turn that
        # observes the soft-timeout inject answers with a report_to_parent
        # tool call carrying this payload.
        self.report_on_release: tuple[str, str, dict] | None = None
        # Set the first time a "[SOFT TIMEOUT]" message is seen in the
        # conversation history passed to stream().
        self.inject_seen = asyncio.Event()
        # Number of stream() calls so far; turn 1 is the pre-inject turn.
        self._turn_count: int = 0

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self._turn_count += 1
        # On the second call (after the watcher's inject), check whether the
        # SOFT TIMEOUT message arrived in the conversation.
        if self._turn_count >= 2:
            for m in messages:
                content = m.get("content", "")
                if isinstance(content, str) and "[SOFT TIMEOUT]" in content:
                    self.inject_seen.set()
                    if self.report_on_release:
                        st, summary, data = self.report_on_release
                        # Answer the warning with an explicit partial report.
                        yield ToolCallEvent(
                            tool_use_id=f"tu_report_{self._turn_count}",
                            tool_name="report_to_parent",
                            tool_input={"status": st, "summary": summary, "data": data},
                        )
                    yield FinishEvent(stop_reason="tool_calls", input_tokens=1, output_tokens=1, model="mock-slow")
                    return
            # Otherwise loop forever (ignore warning).
            await self.stall_event.wait()
            yield FinishEvent(stop_reason="stop", input_tokens=1, output_tokens=1, model="mock-slow")
            return

        # First turn: stall until released.
        await self.stall_event.wait()
        yield TextDeltaEvent(content="thinking...", snapshot="thinking...")
        yield FinishEvent(stop_reason="stop", input_tokens=1, output_tokens=1, model="mock-slow")

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        # Non-streaming path is never exercised by these tests; return an
        # empty canned response.
        return LLMResponse(content="", model="mock-slow", stop_reason="stop")
|
||||
|
||||
|
||||
async def _build_colony(tmp_path: Path, llm: LLMProvider, colony_id: str) -> ColonyRuntime:
    """Create, start, and return a minimal ColonyRuntime for timeout tests.

    Uses throwaway single-letter spec/goal fields, no tools, and stores the
    colony's state under ``tmp_path / colony_id``.
    """
    spec = AgentSpec(
        id="t",
        name="t",
        description="t",
        system_prompt="t",
        agent_type="event_loop",
        tool_access_policy="all",
    )
    colony = ColonyRuntime(
        agent_spec=spec,
        goal=Goal(id="g", name="g", description="g"),
        storage_path=tmp_path / colony_id,
        llm=llm,
        tools=[],
        tool_executor=_stub_executor,
        event_bus=EventBus(),
        colony_id=colony_id,
        pipeline_stages=[],
    )
    await colony.start()
    return colony
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_watch_batch_timeouts_soft_inject_only_hits_stragglers(
    tmp_path: Path,
) -> None:
    """A worker that already filed an explicit report must not be sent the
    SOFT TIMEOUT warning inject by the batch watcher."""
    llm = _ByTaskMockLLM(by_task={"fast": _report("success", "fast done", {})})
    colony = await _build_colony(tmp_path, llm, "soft_fast")

    try:
        worker_ids = await colony.spawn_batch([{"task": "fast"}])
        worker = colony._workers[worker_ids[0]]

        # Let the worker run to natural completion.
        remaining = 50
        while remaining and worker.is_active:
            await asyncio.sleep(0.05)
            remaining -= 1
        assert not worker.is_active
        assert worker._explicit_report is not None  # report_to_parent was called

        # Record the queue depth, then run a watcher with tiny timeouts.
        depth_before = worker._input_queue.qsize()
        watcher = colony.watch_batch_timeouts(
            worker_ids,
            soft_timeout=0.1,
            hard_timeout=0.2,
        )
        await watcher
        # Worker had already finished and reported — no inject expected.
        assert worker._input_queue.qsize() == depth_before
    finally:
        await colony.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_explicit_report_survives_cancel(tmp_path: Path) -> None:
    """A worker whose _explicit_report was set just before cancellation must
    emit a SUBAGENT_REPORT carrying that explicit payload rather than the
    canned 'Worker was cancelled' stub."""
    mock_llm = _ByTaskMockLLM(
        by_task={"cancel-me": _report("success", "partial wrap-up", {"items_done": 3})}
    )
    colony = await _build_colony(tmp_path, mock_llm, "cancel_survives")

    reports: list[dict] = []

    async def _capture(event: AgentEvent) -> None:
        reports.append(event.data or {})

    colony.event_bus.subscribe(event_types=[EventType.SUBAGENT_REPORT], handler=_capture)

    try:
        worker_ids = await colony.spawn_batch([{"task": "cancel-me"}])
        worker = colony._workers[worker_ids[0]]

        # Poll until the worker's first turn lands its explicit report.
        remaining = 50
        while remaining and worker._explicit_report is None:
            await asyncio.sleep(0.05)
            remaining -= 1
        assert worker._explicit_report is not None, (
            "Worker never set _explicit_report — test precondition not met"
        )

        # Cancel the worker that already reported.
        await colony.stop_worker(worker_ids[0])

        # Give the event bus a moment to deliver the terminal report.
        remaining = 20
        while remaining and not reports:
            await asyncio.sleep(0.05)
            remaining -= 1

        # The delivered report must be the explicit one.
        assert reports, "No SUBAGENT_REPORT emitted"
        # Find the cancel-survives worker's report (there should only be one).
        report = reports[0]
        assert report.get("summary") == "partial wrap-up", report
        assert report.get("data", {}).get("items_done") == 3, report
    finally:
        await colony.stop()
|
||||
|
||||
Reference in New Issue
Block a user