Fix: SessionManager._cleanup_stale_active_sessions indiscriminately cancels healthy concurrent agent sessions (#6081)
* Fixes a bug in the SessionManager
* chore: remove debug print from test

---------

Co-authored-by: hundao <alchemy_wimp@hotmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ Each stream has:
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
@@ -963,6 +964,9 @@ class ExecutionStream:
|
||||
if error:
|
||||
state.result.error = error
|
||||
|
||||
# Stamp the owning process ID for cross-process stale detection
|
||||
state.pid = os.getpid()
|
||||
|
||||
# Write state.json
|
||||
await self._session_store.write_state(execution_id, state)
|
||||
logger.debug(f"Wrote state.json for session {execution_id} (status={status})")
|
||||
|
||||
@@ -134,6 +134,9 @@ class SessionState(BaseModel):
|
||||
# Input data (for debugging/replay)
|
||||
input_data: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Process ID of the owning process (for cross-process stale session detection)
|
||||
pid: int | None = None
|
||||
|
||||
# Isolation level (from ExecutionContext)
|
||||
isolation_level: str = "shared"
|
||||
|
||||
|
||||
@@ -278,11 +278,20 @@ class SessionManager:
|
||||
When a new runtime starts, any on-disk session still marked 'active'
|
||||
is from a process that no longer exists. 'Paused' sessions are left
|
||||
intact so they remain resumable.
|
||||
|
||||
Two-layer protection against corrupting live sessions:
|
||||
1. In-memory: skip any session ID currently tracked in self._sessions
|
||||
(guaranteed alive in this process).
|
||||
2. PID validation: if state.json contains a ``pid`` field, check whether
|
||||
that process is still running on the host. If it is, the session is
|
||||
owned by another healthy worker process, so leave it alone.
|
||||
"""
|
||||
sessions_path = Path.home() / ".hive" / "agents" / agent_path.name / "sessions"
|
||||
if not sessions_path.exists():
|
||||
return
|
||||
|
||||
live_session_ids = set(self._sessions.keys())
|
||||
|
||||
for d in sessions_path.iterdir():
|
||||
if not d.is_dir() or not d.name.startswith("session_"):
|
||||
continue
|
||||
@@ -293,6 +302,26 @@ class SessionManager:
|
||||
state = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
if state.get("status") != "active":
|
||||
continue
|
||||
|
||||
# Layer 1: skip sessions that are alive in this process
|
||||
session_id = state.get("session_id", d.name)
|
||||
if session_id in live_session_ids or d.name in live_session_ids:
|
||||
logger.debug(
|
||||
"Skipping live in-memory session '%s' during stale cleanup",
|
||||
d.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Layer 2: skip sessions whose owning process is still alive
|
||||
recorded_pid = state.get("pid")
|
||||
if recorded_pid is not None and self._is_pid_alive(recorded_pid):
|
||||
logger.debug(
|
||||
"Skipping session '%s' — owning process %d is still running",
|
||||
d.name,
|
||||
recorded_pid,
|
||||
)
|
||||
continue
|
||||
|
||||
state["status"] = "cancelled"
|
||||
state.setdefault("result", {})["error"] = "Stale session: runtime restarted"
|
||||
state.setdefault("timestamps", {})["updated_at"] = datetime.now().isoformat()
|
||||
@@ -303,6 +332,34 @@ class SessionManager:
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to clean up stale session %s: %s", d.name, e)
|
||||
|
||||
@staticmethod
|
||||
def _is_pid_alive(pid: int) -> bool:
|
||||
"""Check whether a process with the given PID is still running."""
|
||||
import os
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import ctypes
|
||||
|
||||
# PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
handle = kernel32.OpenProcess(0x1000, False, pid)
|
||||
if not handle:
|
||||
# 5 is ERROR_ACCESS_DENIED, meaning the process exists but is protected
|
||||
return kernel32.GetLastError() == 5
|
||||
|
||||
exit_code = ctypes.c_ulong()
|
||||
kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code))
|
||||
kernel32.CloseHandle(handle)
|
||||
# 259 is STILL_ACTIVE
|
||||
return exit_code.value == 259
|
||||
else:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except OSError:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def load_worker(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -37,6 +37,7 @@ class MockNodeSpec:
|
||||
client_facing: bool = False
|
||||
success_criteria: str | None = None
|
||||
system_prompt: str | None = None
|
||||
sub_agents: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -67,6 +68,7 @@ class MockEntryPoint:
|
||||
name: str = "Default"
|
||||
entry_node: str = "start"
|
||||
trigger_type: str = "manual"
|
||||
trigger_config: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -130,6 +132,9 @@ class MockRuntime:
|
||||
def get_stats(self):
|
||||
return {"running": True, "executions": 1}
|
||||
|
||||
def get_timer_next_fire_in(self, ep_id):
|
||||
return None
|
||||
|
||||
|
||||
class MockAgentInfo:
|
||||
name: str = "test_agent"
|
||||
@@ -1556,3 +1561,106 @@ class TestErrorMiddleware:
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
resp = await client.get("/api/nonexistent")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
class TestCleanupStaleActiveSessions:
    """Tests for _cleanup_stale_active_sessions with two-layer protection."""

    # A PID that is almost certainly not running on the test host.
    _DEAD_PID = 999999999

    def _make_manager(self):
        """Build a fresh SessionManager (imported lazily, as in the original test)."""
        from framework.server.session_manager import SessionManager

        return SessionManager()

    def _prepare(self, tmp_path, monkeypatch, session_name: str) -> tuple[Path, Path]:
        """Redirect ``Path.home()`` to *tmp_path* and return ``(agent_path, session_dir)``.

        The session directory mirrors the layout the cleanup routine scans:
        ``<home>/.hive/agents/<agent name>/sessions/<session_name>``.
        """
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        agent_path = Path("my_agent")
        sessions_dir = tmp_path / ".hive" / "agents" / "my_agent" / "sessions"
        return agent_path, sessions_dir / session_name

    def _write_state(self, session_dir: Path, status: str, pid: int | None = None) -> None:
        """Write a minimal state.json; ``pid`` is only included when given."""
        session_dir.mkdir(parents=True, exist_ok=True)
        state: dict = {"status": status, "session_id": session_dir.name}
        if pid is not None:
            state["pid"] = pid
        # The cleanup code reads state.json as UTF-8, so write it the same way.
        (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    def _read_state(self, session_dir: Path) -> dict:
        """Load state.json back from *session_dir*."""
        return json.loads((session_dir / "state.json").read_text(encoding="utf-8"))

    def test_stale_session_is_cancelled(self, tmp_path, monkeypatch):
        """Truly stale active sessions (no live tracking, no PID) get cancelled."""
        agent_path, session_dir = self._prepare(tmp_path, monkeypatch, "session_stale_001")
        self._write_state(session_dir, "active")

        mgr = self._make_manager()
        mgr._cleanup_stale_active_sessions(agent_path)

        state = self._read_state(session_dir)
        assert state["status"] == "cancelled"
        assert "Stale session" in state["result"]["error"]

    def test_live_in_memory_session_is_skipped(self, tmp_path, monkeypatch):
        """Sessions tracked in self._sessions must NOT be cancelled (Layer 1)."""
        agent_path, session_dir = self._prepare(tmp_path, monkeypatch, "session_live_002")
        self._write_state(session_dir, "active")

        mgr = self._make_manager()
        # Simulate a live session in the manager's in-memory map
        mgr._sessions["session_live_002"] = MagicMock()

        mgr._cleanup_stale_active_sessions(agent_path)

        state = self._read_state(session_dir)
        assert state["status"] == "active", "Live in-memory session should NOT be cancelled"

    def test_session_with_live_pid_is_skipped(self, tmp_path, monkeypatch):
        """Sessions whose owning PID is still alive must NOT be cancelled (Layer 2)."""
        import os

        agent_path, session_dir = self._prepare(tmp_path, monkeypatch, "session_pid_003")
        # Use the current process PID — guaranteed to be alive
        self._write_state(session_dir, "active", pid=os.getpid())

        mgr = self._make_manager()
        mgr._cleanup_stale_active_sessions(agent_path)

        state = self._read_state(session_dir)
        assert state["status"] == "active", "Session with live PID should NOT be cancelled"

    def test_session_with_dead_pid_is_cancelled(self, tmp_path, monkeypatch):
        """Sessions whose owning PID is dead should be cancelled."""
        agent_path, session_dir = self._prepare(tmp_path, monkeypatch, "session_dead_004")
        self._write_state(session_dir, "active", pid=self._DEAD_PID)

        mgr = self._make_manager()
        mgr._cleanup_stale_active_sessions(agent_path)

        state = self._read_state(session_dir)
        assert state["status"] == "cancelled"
        assert "Stale session" in state["result"]["error"]

    def test_paused_session_is_never_touched(self, tmp_path, monkeypatch):
        """Paused sessions should remain intact regardless of PID or tracking."""
        agent_path, session_dir = self._prepare(tmp_path, monkeypatch, "session_paused_005")
        # Untracked AND owned by a dead PID: the only thing protecting this
        # session from cancellation is its 'paused' status.
        self._write_state(session_dir, "paused", pid=self._DEAD_PID)

        mgr = self._make_manager()
        mgr._cleanup_stale_active_sessions(agent_path)

        state = self._read_state(session_dir)
        assert state["status"] == "paused", "Paused sessions must remain untouched"
|
||||
|
||||
@@ -53,6 +53,8 @@ class StateWriter:
|
||||
# Write to new format if enabled
|
||||
if self.dual_write_enabled:
|
||||
try:
|
||||
# Stamp the owning process ID for cross-process stale detection
|
||||
state.pid = os.getpid()
|
||||
await self.new.write_state(session_id, state)
|
||||
logger.debug(f"Wrote state.json for session {session_id}")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user