fix(runtime): prevent session restart until cancelled execution fully terminates (#7001)
* fix(runtime): prevent dual execution after forced cancel
  - keep bookkeeping until task termination
  - block restart while any execution task is still alive
  - make execution registration atomic under lock
  - avoid premature cleanup on cancel timeout
  - add regression tests for forced-cancel restart scenarios

* chore: ruff format and import order

---------

Co-authored-by: kowshikmente <kowshikmente@kowshikmentes-MacBook-Pro.local>
Co-authored-by: hundao <alchemy_wimp@hotmail.com>
@@ -1672,7 +1672,7 @@ class AgentHost:
         entry_point_id: str,
         execution_id: str,
         graph_id: str | None = None,
-    ) -> bool:
+    ) -> str:
         """
         Cancel a running execution.

@@ -1682,11 +1682,11 @@ class AgentHost:
             graph_id: Graph to search (defaults to active graph)

         Returns:
-            True if cancelled, False if not found
+            Cancellation outcome from the stream.
         """
         stream = self._resolve_stream(entry_point_id, graph_id)
         if stream is None:
-            return False
+            return "not_found"
         return await stream.cancel_execution(execution_id)

     # === QUERY OPERATIONS ===

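The practical effect of the `bool` → `str` change for callers, as a minimal sketch (the `host` handle and helper name are hypothetical):

```python
async def restart_after_cancel(host, entry_point_id: str, execution_id: str) -> bool:
    """Hypothetical caller: only restart once the old task has fully exited."""
    outcome = await host.cancel_execution(entry_point_id, execution_id)
    if outcome == "cancelled":
        return True   # task fully terminated; safe to start a replacement
    if outcome == "cancelling":
        return False  # still winding down; a restart now would overlap
    return False      # "not_found": nothing was running
```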
@@ -16,7 +16,7 @@ from collections import OrderedDict
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal

 from framework.host.event_bus import EventBus
 from framework.host.shared_state import IsolationLevel, SharedBufferManager
@@ -48,6 +48,8 @@ class ExecutionAlreadyRunningError(RuntimeError):

 logger = logging.getLogger(__name__)

+CancelExecutionResult = Literal["cancelled", "cancelling", "not_found"]
+

 class GraphScopedEventBus(EventBus):
     """Proxy that stamps ``graph_id`` on every published event.

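Because `CancelExecutionResult` is a `Literal`, a type checker can enforce that callers handle all three outcomes. A small sketch of exhaustive matching (the `describe` helper is illustrative, not part of this change):

```python
from typing import assert_never  # Python 3.11+; use typing_extensions on older versions

from framework.host.execution_manager import CancelExecutionResult


def describe(outcome: CancelExecutionResult) -> str:
    if outcome == "cancelled":
        return "task fully exited within the grace period"
    if outcome == "cancelling":
        return "cancel requested; task still shutting down"
    if outcome == "not_found":
        return "no active task for that execution id"
    assert_never(outcome)  # type checker flags any newly added literal
```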
@@ -130,7 +132,7 @@ class ExecutionContext:
     run_id: str | None = None  # Unique ID per trigger() invocation
     started_at: datetime = field(default_factory=datetime.now)
     completed_at: datetime | None = None
-    status: str = "pending"  # pending, running, completed, failed, paused
+    status: str = "pending"  # pending, running, cancelling, completed, failed, paused, cancelled


 class ExecutionManager:

@@ -315,6 +317,22 @@ class ExecutionManager:
         """Return IDs of all currently active executions."""
         return list(self._active_executions.keys())

+    def _get_blocking_execution_ids_locked(self) -> list[str]:
+        """Return executions that still block a replacement from starting.
+
+        An execution continues to block replacement until its task has
+        terminated and the task's final cleanup has removed its bookkeeping.
+        This is intentional: a timed-out cancellation does not mean the old
+        task is harmless. If it is still alive, it can still write shared
+        session state, so letting a replacement start would guarantee
+        overlapping mutations on the same session.
+        """
+        blocking_ids: list[str] = list(self._active_executions.keys())
+        for execution_id, task in self._execution_tasks.items():
+            if not task.done() and execution_id not in self._active_executions:
+                blocking_ids.append(execution_id)
+        return blocking_ids
+
     @property
     def agent_idle_seconds(self) -> float:
         """Seconds since the last agent activity (LLM call, tool call, node transition).

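The reason `task.done()` is the gate: an asyncio task that absorbs its first `CancelledError` to run cleanup stays alive past the cancel request. A self-contained sketch (illustrative only) of that window:

```python
import asyncio


async def stubborn() -> None:
    # Simulates an executor that needs time to wind down: it absorbs the
    # first CancelledError and finishes its cleanup before re-raising.
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        await asyncio.sleep(0.2)  # cleanup still in flight
        raise


async def main() -> None:
    task = asyncio.create_task(stubborn())
    await asyncio.sleep(0)      # let the task start
    task.cancel()
    await asyncio.sleep(0.05)   # grace period elapses before cleanup finishes
    print(task.done())          # False: cancelled, but still alive and running
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```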
@@ -396,15 +414,22 @@ class ExecutionManager:

     async def stop(self) -> None:
         """Stop the execution stream and cancel active executions."""
-        if not self._running:
-            return
+        async with self._lock:
+            if not self._running:
+                return

-        self._running = False
+            self._running = False

-        # Cancel all active executions
-        tasks_to_wait = []
-        for _, task in self._execution_tasks.items():
-            if not task.done():
+            # Cancel all active executions, but keep bookkeeping until each
+            # task reaches its own cleanup path.
+            tasks_to_wait: list[asyncio.Task] = []
+            for execution_id, task in self._execution_tasks.items():
+                if task.done():
+                    continue
+                ctx = self._active_executions.get(execution_id)
+                if ctx is not None:
+                    ctx.status = "cancelling"
+                self._cancel_reasons.setdefault(execution_id, "Execution cancelled")
                 task.cancel()
                 tasks_to_wait.append(task)

@@ -418,9 +443,6 @@ class ExecutionManager:
                 len(pending),
             )

-        self._execution_tasks.clear()
-        self._active_executions.clear()
-
         logger.info(f"ExecutionStream '{self.stream_id}' stopped")

         # Emit stream stopped event

@@ -569,12 +591,16 @@ class ExecutionManager:
         )

         async with self._lock:
             if not self._running:
                 raise RuntimeError(f"ExecutionStream '{self.stream_id}' is not running")

+            blocking_ids = self._get_blocking_execution_ids_locked()
+            if blocking_ids:
+                raise ExecutionAlreadyRunningError(self.stream_id, blocking_ids)
+
             self._active_executions[execution_id] = ctx
             self._completion_events[execution_id] = asyncio.Event()

             # Start execution task
-            task = asyncio.create_task(self._run_execution(ctx))
-            self._execution_tasks[execution_id] = task
+            self._execution_tasks[execution_id] = asyncio.create_task(self._run_execution(ctx))

         logger.debug(f"Queued execution {execution_id} for stream {self.stream_id}")
         return execution_id

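Why `create_task` moved under the lock: with the availability check and the task registration in separate critical sections, two concurrent `execute()` calls can both pass the check before either registers its task. A simplified sketch of the atomic check-and-register pattern (names are illustrative, not the real class):

```python
import asyncio


class Registry:
    """Minimal sketch: check and register must be one critical section."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._tasks: dict[str, asyncio.Task] = {}

    async def register(self, execution_id: str, coro) -> None:
        async with self._lock:
            # No other coroutine can pass this check before the new task is
            # recorded, because creation happens inside the same lock.
            if any(not t.done() for t in self._tasks.values()):
                coro.close()  # avoid an un-awaited coroutine warning
                raise RuntimeError("an execution is still running")
            self._tasks[execution_id] = asyncio.create_task(coro)
```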
@@ -1183,7 +1209,7 @@ class ExecutionManager:
         """Get execution context."""
         return self._active_executions.get(execution_id)

-    async def cancel_execution(self, execution_id: str, *, reason: str | None = None) -> bool:
+    async def cancel_execution(self, execution_id: str, *, reason: str | None = None) -> CancelExecutionResult:
         """
         Cancel a running execution.

@@ -1194,33 +1220,38 @@ class ExecutionManager:
             provided, defaults to "Execution cancelled".

         Returns:
-            True if cancelled, False if not found
+            "cancelled" if the task fully exited within the grace period,
+            "cancelling" if cancellation was requested but the task is still
+            shutting down, or "not_found" if no active task exists.
         """
-        task = self._execution_tasks.get(execution_id)
-        if task and not task.done():
+        async with self._lock:
+            task = self._execution_tasks.get(execution_id)
+            if task is None or task.done():
+                return "not_found"
+
             # Store the reason so the CancelledError handler can use it
             # when emitting the pause/fail event.
             self._cancel_reasons[execution_id] = reason or "Execution cancelled"
+            ctx = self._active_executions.get(execution_id)
+            if ctx is not None:
+                ctx.status = "cancelling"
             task.cancel()
-            # Wait briefly for the task to finish. Don't block indefinitely —
-            # the task may be stuck in a long LLM API call that doesn't
-            # respond to cancellation quickly.
-            done, _ = await asyncio.wait({task}, timeout=5.0)
-            if not done:
-                # Task didn't finish within timeout — clean up bookkeeping now
-                # so the session doesn't think it still has running executions.
-                # The task will continue winding down in the background and its
-                # finally block will harmlessly pop already-removed keys.
-                logger.warning(
-                    "Execution %s did not finish within cancel timeout; force-cleaning bookkeeping",
-                    execution_id,
-                )
-                async with self._lock:
-                    self._active_executions.pop(execution_id, None)
-                    self._execution_tasks.pop(execution_id, None)
-                    self._active_executors.pop(execution_id, None)
-            return True
-        return False
+
+        # Wait briefly for the task to finish. Don't block indefinitely —
+        # the task may be stuck in a long LLM API call that doesn't
+        # respond to cancellation quickly.
+        done, _ = await asyncio.wait({task}, timeout=5.0)
+        if not done:
+            # Keep bookkeeping in place until the task's own finally block runs.
+            # We intentionally do not add deferred cleanup keyed by execution_id
+            # here because resumed executions reuse the same id; a delayed pop
+            # could otherwise delete bookkeeping that belongs to the new run.
+            logger.warning(
+                "Execution %s did not finish within cancel timeout; leaving bookkeeping in place until task exit",
+                execution_id,
+            )
+            return "cancelling"
+        return "cancelled"

     # === STATS AND MONITORING ===

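A caller that wants restart-after-cancel semantics on top of this API might look roughly like this (a hypothetical helper; the retry cadence is arbitrary):

```python
import asyncio

from framework.host.execution_manager import ExecutionAlreadyRunningError


async def cancel_and_restart(stream, execution_id: str, payload: dict) -> str:
    """Hypothetical helper: retry the restart until the old task has exited."""
    outcome = await stream.cancel_execution(execution_id, reason="restart requested")
    if outcome == "not_found":
        return await stream.execute(payload)

    # "cancelling" means the old task survived the 5 s grace period. Poll
    # until its final cleanup removes the bookkeeping, then start the new run.
    while True:
        try:
            return await stream.execute(payload)
        except ExecutionAlreadyRunningError:
            await asyncio.sleep(0.1)
```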
@@ -0,0 +1,180 @@
+"""Regression tests for forced cancellation overlap in ExecutionStream."""
+
+from __future__ import annotations
+
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from framework.host.event_bus import AgentEvent, EventBus, EventType
+from framework.host.execution_manager import (
+    EntryPointSpec,
+    ExecutionAlreadyRunningError,
+    ExecutionManager,
+)
+from framework.orchestrator.edge import GraphSpec
+from framework.orchestrator.goal import Goal
+from framework.orchestrator.orchestrator import ExecutionResult
+
+
+def _build_stream(tmp_path, *, event_bus: EventBus | None = None) -> ExecutionManager:
+    graph = GraphSpec(
+        id="test-graph",
+        goal_id="goal-1",
+        version="1.0.0",
+        entry_node="start",
+        entry_points={"start": "start"},
+        terminal_nodes=[],
+        pause_nodes=[],
+        nodes=[],
+        edges=[],
+    )
+    goal = Goal(id="goal-1", name="goal-1", description="test goal")
+    entry_spec = EntryPointSpec(
+        id="webhook",
+        name="Webhook",
+        entry_node="start",
+        trigger_type="webhook",
+        isolation_level="shared",
+        max_concurrent=1,
+    )
+
+    storage = SimpleNamespace(base_path=tmp_path)
+    stream = ExecutionManager(
+        stream_id="webhook",
+        entry_spec=entry_spec,
+        graph=graph,
+        goal=goal,
+        state_manager=MagicMock(),
+        storage=storage,
+        outcome_aggregator=MagicMock(),
+        event_bus=event_bus,
+    )
+    stream._running = True
+    return stream
+
+
+def _install_blocking_executor(monkeypatch, release: asyncio.Event) -> None:
+    class BlockingExecutor:
+        def __init__(self, *args, **kwargs):
+            self.node_registry = {}
+
+        async def execute(self, *args, **kwargs):
+            while True:
+                try:
+                    await release.wait()
+                    break
+                except asyncio.CancelledError:
+                    continue
+            return ExecutionResult(success=True, output={"ok": True})
+
+    monkeypatch.setattr("framework.host.execution_manager.Orchestrator", BlockingExecutor)
+
+
+@pytest.mark.asyncio
+async def test_forced_cancel_timeout_keeps_stream_locked_until_task_exit(tmp_path, monkeypatch):
+    event_bus = EventBus()
+    stream = _build_stream(tmp_path, event_bus=event_bus)
+    release = asyncio.Event()
+    _install_blocking_executor(monkeypatch, release)
+
+    started_events: list[AgentEvent] = []
+    first_started = asyncio.Event()
+    second_started = asyncio.Event()
+
+    async def on_started(event: AgentEvent) -> None:
+        started_events.append(event)
+        if len(started_events) == 1:
+            first_started.set()
+        elif len(started_events) == 2:
+            second_started.set()
+
+    event_bus.subscribe(
+        event_types=[EventType.EXECUTION_STARTED],
+        handler=on_started,
+        filter_stream="webhook",
+    )
+
+    async def immediate_timeout(_tasks, timeout=None):
+        return set(), set(_tasks)
+
+    execution_id = await stream.execute({}, session_state={"resume_session_id": "session-1"})
+    await asyncio.wait_for(first_started.wait(), timeout=1)
+
+    old_task = stream._execution_tasks[execution_id]
+    monkeypatch.setattr("framework.host.execution_manager.asyncio.wait", immediate_timeout)
+
+    try:
+        cancelled = await stream.cancel_execution(execution_id, reason="forced timeout")
+
+        assert cancelled == "cancelling"
+        assert execution_id in stream._execution_tasks
+        assert execution_id in stream._active_executions
+        assert execution_id in stream._completion_events
+        assert stream._active_executions[execution_id].status == "cancelling"
+        assert not old_task.done()
+
+        with pytest.raises(ExecutionAlreadyRunningError):
+            await stream.execute({}, session_state={"resume_session_id": execution_id})
+
+        assert len(started_events) == 1
+
+        release.set()
+        await asyncio.wait_for(old_task, timeout=1)
+
+        restarted_id = await stream.execute({}, session_state={"resume_session_id": execution_id})
+        assert restarted_id == execution_id
+        await asyncio.wait_for(second_started.wait(), timeout=1)
+    finally:
+        release.set()
+        await asyncio.gather(*stream._execution_tasks.values(), return_exceptions=True)
+
+
+@pytest.mark.asyncio
+async def test_repeated_forced_restarts_do_not_accumulate_parallel_tasks(tmp_path, monkeypatch):
+    event_bus = EventBus()
+    stream = _build_stream(tmp_path, event_bus=event_bus)
+    release = asyncio.Event()
+    _install_blocking_executor(monkeypatch, release)
+
+    started_events: list[AgentEvent] = []
+    first_started = asyncio.Event()
+
+    async def on_started(event: AgentEvent) -> None:
+        started_events.append(event)
+        first_started.set()
+
+    event_bus.subscribe(
+        event_types=[EventType.EXECUTION_STARTED],
+        handler=on_started,
+        filter_stream="webhook",
+    )
+
+    async def immediate_timeout(_tasks, timeout=None):
+        return set(), set(_tasks)
+
+    monkeypatch.setattr("framework.host.execution_manager.asyncio.wait", immediate_timeout)
+
+    execution_id = await stream.execute({}, session_state={"resume_session_id": "session-1"})
+    await asyncio.wait_for(first_started.wait(), timeout=1)
+
+    first_task = stream._execution_tasks[execution_id]
+
+    try:
+        assert await stream.cancel_execution(execution_id, reason="restart-1") == "cancelling"
+
+        with pytest.raises(ExecutionAlreadyRunningError):
+            await stream.execute({}, session_state={"resume_session_id": execution_id})
+
+        with pytest.raises(ExecutionAlreadyRunningError):
+            await stream.execute({}, session_state={"resume_session_id": execution_id})
+
+        assert len(started_events) == 1
+        assert list(stream._execution_tasks) == [execution_id]
+        assert stream._execution_tasks[execution_id] is first_task
+        assert not first_task.done()
+    finally:
+        release.set()
+        await asyncio.wait_for(first_task, timeout=1)

@@ -10,6 +10,7 @@ from aiohttp import web

 from framework.agent_loop.conversation import LEGACY_RUN_ID
 from framework.credentials.validation import validate_agent_credentials
+from framework.host.execution_manager import ExecutionAlreadyRunningError
 from framework.server.app import resolve_session, safe_path_segment, sessions_dir
 from framework.server.routes_sessions import _credential_error_response

@@ -100,6 +101,17 @@ def _resolve_queen_only_tools() -> frozenset[str]:
     return frozenset(derived | _QUEEN_LIFECYCLE_EXTRAS)


+def _execution_already_running_response(exc: ExecutionAlreadyRunningError) -> web.Response:
+    return web.json_response(
+        {
+            "error": str(exc),
+            "stream_id": exc.stream_id,
+            "active_execution_ids": exc.active_ids,
+        },
+        status=409,
+    )
+
+
 async def handle_trigger(request: web.Request) -> web.Response:
     """POST /api/sessions/{session_id}/trigger — start an execution.

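Client code can now distinguish the busy case by status code. A rough aiohttp client sketch (the URL layout is taken from the route docstring; function and parameter names are illustrative):

```python
import aiohttp


async def trigger(base_url: str, session_id: str, payload: dict) -> str:
    """Hypothetical client for POST /api/sessions/{session_id}/trigger."""
    async with aiohttp.ClientSession() as http:
        url = f"{base_url}/api/sessions/{session_id}/trigger"
        async with http.post(url, json=payload) as resp:
            data = await resp.json()
            if resp.status == 409:
                # Body mirrors _execution_already_running_response above.
                raise RuntimeError(
                    f"stream {data['stream_id']} still busy: "
                    f"{data['active_execution_ids']}"
                )
            return data["execution_id"]
```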
@@ -141,11 +153,14 @@ async def handle_trigger(request: web.Request) -> web.Response:
     if "resume_session_id" not in session_state:
         session_state["resume_session_id"] = session.id

-    execution_id = await session.colony_runtime.trigger(
-        entry_point_id,
-        input_data,
-        session_state=session_state,
-    )
+    try:
+        execution_id = await session.colony_runtime.trigger(
+            entry_point_id,
+            input_data,
+            session_state=session_state,
+        )
+    except ExecutionAlreadyRunningError as exc:
+        return _execution_already_running_response(exc)

     # Cancel queen's in-progress LLM turn so it picks up the phase change cleanly
     if session.queen_executor:

@@ -434,11 +449,14 @@ async def handle_resume(request: web.Request) -> web.Response:

     input_data = state.get("input_data", {})

-    execution_id = await session.colony_runtime.trigger(
-        entry_points[0].id,
-        input_data=input_data,
-        session_state=resume_session_state,
-    )
+    try:
+        execution_id = await session.colony_runtime.trigger(
+            entry_points[0].id,
+            input_data=input_data,
+            session_state=resume_session_state,
+        )
+    except ExecutionAlreadyRunningError as exc:
+        return _execution_already_running_response(exc)

     return web.json_response(
         {

@@ -465,6 +483,7 @@ async def handle_pause(request: web.Request) -> web.Response:

     runtime = session.colony_runtime
     cancelled = []
+    cancelling = []

     for colony_id in runtime.list_graphs():
         reg = runtime.get_graph_registration(colony_id)

@@ -481,23 +500,26 @@ async def handle_pause(request: web.Request) -> web.Response:

         for exec_id in list(stream.active_execution_ids):
             try:
-                ok = await stream.cancel_execution(exec_id, reason="Execution paused by user")
-                if ok:
+                outcome = await stream.cancel_execution(exec_id, reason="Execution paused by user")
+                if outcome == "cancelled":
                     cancelled.append(exec_id)
+                elif outcome == "cancelling":
+                    cancelling.append(exec_id)
             except Exception:
                 pass

     # Pause timers so the next tick doesn't restart execution
     runtime.pause_timers()

-    # Switch to staging (agent still loaded, ready to re-run)
-    if session.phase_state is not None:
+    # Only switch to staging once every execution has actually stopped.
+    if session.phase_state is not None and not cancelling:
         await session.phase_state.switch_to_staging(source="frontend")

     return web.json_response(
         {
-            "stopped": bool(cancelled),
+            "stopped": bool(cancelled) and not cancelling,
             "cancelled": cancelled,
+            "cancelling": cancelling,
             "timers_paused": True,
         }
     )

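A pause client would treat a non-empty `cancelling` list as "not yet stopped". A sketch of that interpretation (the `/pause` path is assumed from the handler name; it is not confirmed by this diff):

```python
import aiohttp


async def pause_session(base_url: str, session_id: str) -> bool:
    """Hypothetical client: returns True only when every execution stopped."""
    async with aiohttp.ClientSession() as http:
        async with http.post(f"{base_url}/api/sessions/{session_id}/pause") as resp:
            data = await resp.json()
            # "stopped" is now true only when nothing is left in "cancelling";
            # a non-empty list means some task outlived the grace period and
            # the switch to staging was deferred.
            if data["cancelling"]:
                print(f"still shutting down: {data['cancelling']}")
            return data["stopped"]
```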
@@ -534,8 +556,9 @@ async def handle_stop(request: web.Request) -> web.Response:
             if hasattr(node, "cancel_current_turn"):
                 node.cancel_current_turn()

-        cancelled = await stream.cancel_execution(execution_id, reason="Execution stopped by user")
-        if cancelled:
+        outcome = await stream.cancel_execution(execution_id, reason="Execution stopped by user")
+
+        if outcome == "cancelled":
             # Cancel queen's in-progress LLM turn
             if session.queen_executor:
                 node = session.queen_executor.node_registry.get("queen")

@@ -549,9 +572,19 @@ async def handle_stop(request: web.Request) -> web.Response:
             return web.json_response(
                 {
                     "stopped": True,
+                    "cancelling": False,
                     "execution_id": execution_id,
                 }
             )
+        if outcome == "cancelling":
+            return web.json_response(
+                {
+                    "stopped": False,
+                    "cancelling": True,
+                    "execution_id": execution_id,
+                },
+                status=202,
+            )

     return web.json_response({"stopped": False, "error": "Execution not found"}, status=404)

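For `/stop`, clients now see three outcomes instead of a boolean. A minimal mapping sketch (illustrative; the endpoint path matches the tests in this commit):

```python
import aiohttp


async def request_stop(base_url: str, session_id: str, execution_id: str) -> str:
    """Hypothetical client mapping the three /stop responses to outcomes."""
    async with aiohttp.ClientSession() as http:
        url = f"{base_url}/api/sessions/{session_id}/stop"
        async with http.post(url, json={"execution_id": execution_id}) as resp:
            data = await resp.json()
            if resp.status == 200:
                return "stopped"      # task fully exited
            if resp.status == 202:
                return "cancelling"   # accepted; shutdown still in progress
            if resp.status == 404:
                return "not_found"    # nothing running under that id
            raise RuntimeError(f"unexpected HTTP {resp.status}: {data}")
```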
@@ -594,11 +627,14 @@ async def handle_replay(request: web.Request) -> web.Response:
         "run_id": _load_checkpoint_run_id(cp_path),
     }

-    execution_id = await session.colony_runtime.trigger(
-        entry_points[0].id,
-        input_data={},
-        session_state=replay_session_state,
-    )
+    try:
+        execution_id = await session.colony_runtime.trigger(
+            entry_points[0].id,
+            input_data={},
+            session_state=replay_session_state,
+        )
+    except ExecutionAlreadyRunningError as exc:
+        return _execution_already_running_response(exc)

     return web.json_response(
         {

@@ -14,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
 import pytest
 from aiohttp.test_utils import TestClient, TestServer

+from framework.host.execution_manager import ExecutionAlreadyRunningError
 from framework.host.triggers import TriggerDefinition
 from framework.llm.model_catalog import get_models_catalogue
 from framework.server import (

@@ -89,8 +90,8 @@ class MockStream:
     _active_executors: dict = field(default_factory=dict)
     active_execution_ids: set = field(default_factory=set)

-    async def cancel_execution(self, execution_id: str, reason: str | None = None) -> bool:
-        return execution_id in self._execution_tasks
+    async def cancel_execution(self, execution_id: str, reason: str | None = None) -> str:
+        return "cancelled" if execution_id in self._execution_tasks else "not_found"


 @dataclass

@@ -780,6 +781,21 @@ class TestExecution:
             data = await resp.json()
             assert data["execution_id"] == "exec_test_123"

+    @pytest.mark.asyncio
+    async def test_trigger_returns_409_when_execution_still_running(self):
+        session = _make_session()
+        session.colony_runtime.trigger = AsyncMock(side_effect=ExecutionAlreadyRunningError("default", ["session-1"]))
+        app = _make_app_with_session(session)
+        async with TestClient(TestServer(app)) as client:
+            resp = await client.post(
+                "/api/sessions/test_agent/trigger",
+                json={"entry_point_id": "default", "input_data": {"msg": "hi"}},
+            )
+            assert resp.status == 409
+            data = await resp.json()
+            assert data["stream_id"] == "default"
+            assert data["active_execution_ids"] == ["session-1"]
+
     @pytest.mark.asyncio
     async def test_trigger_not_found(self):
         app = create_app()

@@ -918,6 +934,7 @@ class TestExecution:
             data = await resp.json()
             assert data["stopped"] is False
             assert data["cancelled"] == []
+            assert data["cancelling"] == []
             assert data["timers_paused"] is True

     @pytest.mark.asyncio

@@ -1027,6 +1044,22 @@ class TestStop:
             assert resp.status == 200
             data = await resp.json()
             assert data["stopped"] is True
+            assert data["cancelling"] is False
+
+    @pytest.mark.asyncio
+    async def test_stop_returns_accepted_while_execution_is_still_cancelling(self):
+        session = _make_session()
+        session.colony_runtime._mock_streams["default"].cancel_execution = AsyncMock(return_value="cancelling")
+        app = _make_app_with_session(session)
+        async with TestClient(TestServer(app)) as client:
+            resp = await client.post(
+                "/api/sessions/test_agent/stop",
+                json={"execution_id": "exec_abc"},
+            )
+            assert resp.status == 202
+            data = await resp.json()
+            assert data["stopped"] is False
+            assert data["cancelling"] is True

     @pytest.mark.asyncio
     async def test_stop_not_found(self):

@@ -921,7 +921,6 @@ def register_queen_lifecycle_tools(
         the queen called this tool.
         """
         stopped_unified = 0
-        stopped_legacy = 0
         errors: list[str] = []

         # 1. Stop everything on the unified ColonyRuntime. This is

@@ -945,9 +944,7 @@ def register_queen_lifecycle_tools(
         if legacy is not None:
             try:
                 legacy_workers = legacy.list_workers()
-                stopped_legacy = len(legacy_workers) if isinstance(legacy_workers, list) else 0
-                await legacy.stop_all_workers()
-                legacy.pause_timers()
+                _ = len(legacy_workers) if isinstance(legacy_workers, list) else 0
             except Exception as e:
                 errors.append(f"legacy: {e}")
                 logger.warning(

@@ -958,27 +955,74 @@ def register_queen_lifecycle_tools(
         if colony is None and legacy is None:
             return json.dumps({"error": "No runtime on this session."})

-        total_stopped = stopped_unified + stopped_legacy
+        cancelled: list[str] = []
+        cancelling: list[str] = []
+
+        # 3. Stop legacy runtime executions with per-stream cancellation so a
+        # still-alive task keeps the worker in "cancelling" instead of being
+        # reported as fully stopped too early.
+        if legacy is not None:
+            try:
+                for graph_id in legacy.list_graphs():
+                    reg = legacy.get_graph_registration(graph_id)
+                    if reg is None:
+                        continue
+
+                    for _ep_id, stream in reg.streams.items():
+                        for executor in stream._active_executors.values():
+                            for node in executor.node_registry.values():
+                                if hasattr(node, "signal_shutdown"):
+                                    node.signal_shutdown()
+                                if hasattr(node, "cancel_current_turn"):
+                                    node.cancel_current_turn()
+
+                        for exec_id in list(stream.active_execution_ids):
+                            try:
+                                outcome = await stream.cancel_execution(exec_id, reason=reason)
+                                if outcome == "cancelled":
+                                    cancelled.append(exec_id)
+                                elif outcome == "cancelling":
+                                    cancelling.append(exec_id)
+                            except Exception as e:
+                                errors.append(f"legacy-cancel:{exec_id}: {e}")
+                                logger.warning("Failed to cancel %s: %s", exec_id, e)
+
+                legacy.pause_timers()
+            except Exception as e:
+                errors.append(f"legacy-runtime: {e}")
+                logger.warning(
+                    "stop_worker: failed to inspect legacy runtime executions",
+                    exc_info=True,
+                )
+
+        total_stopped = stopped_unified + len(cancelled)
         logger.info(
-            "stop_worker: stopped %d workers (unified=%d, legacy=%d). reason=%s",
-            total_stopped,
+            "stop_worker: status=%s (unified=%d, cancelled=%d, cancelling=%d). reason=%s",
+            "cancelling" if cancelling else "stopped" if total_stopped else "no_active_executions",
             stopped_unified,
-            stopped_legacy,
+            len(cancelled),
+            len(cancelling),
             reason,
         )

         return json.dumps(
             {
-                "status": "stopped",
+                "status": ("cancelling" if cancelling else "stopped" if total_stopped else "no_active_executions"),
                 "workers_stopped": total_stopped,
                 "unified_stopped": stopped_unified,
-                "legacy_stopped": stopped_legacy,
+                "legacy_stopped": len(cancelled),
+                "cancelled": cancelled,
+                "cancelling": cancelling,
+                "timers_paused": legacy is not None,
                 "reason": reason,
                 "errors": errors if errors else None,
             }
         )

+    def _stop_result_allows_phase_transition(stop_result: str) -> tuple[dict, bool]:
+        result = json.loads(stop_result)
+        return result, result.get("status") != "cancelling"
+
     _stop_tool = Tool(
         name="stop_worker",
         description=(

@@ -1561,18 +1605,23 @@ def register_queen_lifecycle_tools(
        inject config adjustments, or escalate to building/planning.
        """
        stop_result = await stop_worker()
+        result, can_transition = _stop_result_allows_phase_transition(stop_result)

-        if phase_state is not None:
+        if phase_state is not None and can_transition:
             await phase_state.switch_to_reviewing()
-            _update_meta_json(session_manager, manager_session_id, {"phase": "editing"})
+            _update_meta_json(session_manager, manager_session_id, {"phase": "reviewing"})

-        result = json.loads(stop_result)
-        result["phase"] = "editing"
-        result["message"] = (
-            "Worker stopped. You are now in editing phase. "
-            "You can re-run with run_agent_with_input(task), tweak config "
-            "with inject_message, or escalate to building/planning."
-        )
+        if can_transition:
+            result["phase"] = "reviewing"
+            result["message"] = (
+                "Worker stopped. You are now in reviewing phase. "
+                "Review the latest results and decide whether to re-run, "
+                "edit the agent, or move into planning."
+            )
+        else:
+            result["message"] = (
+                "Stop requested, but the worker is still shutting down. Phase will not change until shutdown completes."
+            )
         return json.dumps(result)

     _switch_editing_tool = Tool(

@@ -1596,21 +1645,26 @@ def register_queen_lifecycle_tools(
     async def stop_worker_and_review() -> str:
         """Stop the loaded graph and switch to building phase for editing the agent."""
         stop_result = await stop_worker()
+        result, can_transition = _stop_result_allows_phase_transition(stop_result)

         # Switch to building phase
-        if phase_state is not None:
+        if phase_state is not None and can_transition:
             await phase_state.switch_to_building()
             _update_meta_json(session_manager, manager_session_id, {"phase": "building"})

-        result = json.loads(stop_result)
-        result["phase"] = "building"
-        result["message"] = (
-            "Graph stopped. You are now in building phase. "
-            "Use your coding tools to modify the agent, then call "
-            "load_built_agent(path) to stage it again."
-        )
+        if can_transition:
+            result["phase"] = "building"
+            result["message"] = (
+                "Graph stopped. You are now in building phase. "
+                "Use your coding tools to modify the agent, then call "
+                "load_built_agent(path) to stage it again."
+            )
+        else:
+            result["message"] = (
+                "Stop requested, but the worker is still shutting down. Phase will not change until shutdown completes."
+            )
         # Nudge the queen to start coding instead of blocking for user input.
-        if phase_state is not None and phase_state.inject_notification:
+        if can_transition and phase_state is not None and phase_state.inject_notification:
             await phase_state.inject_notification(
                 "[PHASE CHANGE] Switched to BUILDING phase. Start implementing the changes now."
             )

@@ -1633,19 +1687,25 @@ def register_queen_lifecycle_tools(
     async def stop_worker_and_plan() -> str:
         """Stop the loaded graph and switch to planning phase for diagnosis."""
         stop_result = await stop_worker()
+        result, can_transition = _stop_result_allows_phase_transition(stop_result)

         # Switch to planning phase
-        if phase_state is not None:
+        if phase_state is not None and can_transition:
             await phase_state.switch_to_planning(source="tool")
             _update_meta_json(session_manager, manager_session_id, {"phase": "planning"})

-        result = json.loads(stop_result)
-        result["phase"] = "planning"
-        result["message"] = (
-            "Graph stopped. You are now in planning phase. "
-            "Diagnose the issue using read-only tools (checkpoints, logs, sessions), "
-            "discuss a fix plan with the user, then call "
-            "initialize_and_build_agent() to implement the fix."
-        )
+        if can_transition:
+            result["phase"] = "planning"
+            result["message"] = (
+                "Graph stopped. You are now in planning phase. "
+                "Diagnose the issue using read-only tools (checkpoints, logs, sessions), "
+                "discuss a fix plan with the user, then call "
+                "initialize_and_build_agent() to implement the fix."
+            )
+        else:
+            result["message"] = (
+                "Stop requested, but the worker is still shutting down. Phase will not change until shutdown completes."
+            )
         return json.dumps(result)

     _stop_plan_tool = Tool(

@@ -2507,19 +2567,25 @@ def register_queen_lifecycle_tools(
         2. Edit the agent code → call stop_worker_and_review() to go to building phase
         """
         stop_result = await stop_worker()
+        result, can_transition = _stop_result_allows_phase_transition(stop_result)

         # Switch to staging phase
-        if phase_state is not None:
+        if phase_state is not None and can_transition:
             await phase_state.switch_to_staging()
             _update_meta_json(session_manager, manager_session_id, {"phase": "staging"})

-        result = json.loads(stop_result)
-        result["phase"] = "staging"
-        result["message"] = (
-            "Graph stopped. You are now in staging phase. "
-            "Ask the user: would they like to re-run with new input, "
-            "or edit the agent code?"
-        )
+        if can_transition:
+            result["phase"] = "staging"
+            result["message"] = (
+                "Graph stopped. You are now in staging phase. "
+                "Ask the user: would they like to re-run with new input, "
+                "or edit the agent code?"
+            )
+        else:
+            result["message"] = (
+                "Stop requested, but the worker is still shutting down. "
+                "Stay in the current phase until shutdown completes."
+            )
         return json.dumps(result)

     _stop_worker_tool = Tool(