fix: bug causing queen message injection when resuming a session

Richard Tang
2026-04-08 16:48:46 -07:00
parent d19cb2843e
commit 2b8ed0eb05
2 changed files with 150 additions and 8 deletions
+149 -7
@@ -878,13 +878,56 @@ def _sort_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
     )
 
 
+def _load_session_data(
+    logs_dir: Path, session_id: str, limit_files: int
+) -> list[dict[str, Any]] | None:
+    """Load records for a specific session on demand."""
+    if not logs_dir.exists():
+        return None
+    files = sorted(
+        [path for path in logs_dir.iterdir() if path.is_file() and path.suffix == ".jsonl"],
+        key=lambda path: path.stat().st_mtime,
+        reverse=True,
+    )[:limit_files]
+    records: list[dict[str, Any]] = []
+    for path in files:
+        try:
+            with path.open(encoding="utf-8") as handle:
+                for line_number, raw_line in enumerate(handle, start=1):
+                    line = raw_line.strip()
+                    if not line:
+                        continue
+                    try:
+                        payload = json.loads(line)
+                    except json.JSONDecodeError:
+                        payload = {
+                            "timestamp": "",
+                            "execution_id": "",
+                            "assistant_text": "",
+                            "_parse_error": f"{path.name}:{line_number}",
+                            "_raw_line": line,
+                        }
+                    # Only include records for this session
+                    if str(payload.get("execution_id") or "").strip() == session_id:
+                        payload["_log_file"] = str(path)
+                        records.append(payload)
+        except OSError:
+            continue
+    return _sort_records(records) if records else None
+
+
 def _run_server(
     html: str,
-    sessions: dict[str, list[dict[str, Any]]],
+    logs_dir: Path,
+    limit_files: int,
     port: int,
     no_open: bool,
 ) -> None:
     html_bytes = html.encode("utf-8")
+    session_cache: dict[str, list[dict[str, Any]]] = {}
 
     class Handler(http.server.BaseHTTPRequestHandler):
         def do_GET(self) -> None:
@@ -892,11 +935,17 @@ def _run_server(
                 self._respond(200, "text/html; charset=utf-8", html_bytes)
             elif self.path.startswith("/api/session/"):
                 sid = urllib.parse.unquote(self.path[len("/api/session/") :])
-                records = sessions.get(sid)
+                # Check cache first
+                if sid in session_cache:
+                    records = session_cache[sid]
+                else:
+                    records = _load_session_data(logs_dir, sid, limit_files)
+                    if records is not None:
+                        session_cache[sid] = records
                 if records is None:
                     self._respond(404, "application/json", b"[]")
                 else:
-                    body = json.dumps(_sort_records(records), ensure_ascii=False).encode("utf-8")
+                    body = json.dumps(records, ensure_ascii=False).encode("utf-8")
                     self._respond(200, "application/json", body)
             else:
                 self.send_error(404)
@@ -927,13 +976,106 @@ def _run_server(
         server.server_close()
 
 
+def _discover_session_summaries(
+    logs_dir: Path, limit_files: int, include_tests: bool
+) -> list[SessionSummary]:
+    """Discover only session summaries without loading full record data."""
+    if not logs_dir.exists():
+        raise FileNotFoundError(f"log directory not found: {logs_dir}")
+    files = sorted(
+        [path for path in logs_dir.iterdir() if path.is_file() and path.suffix == ".jsonl"],
+        key=lambda path: path.stat().st_mtime,
+        reverse=True,
+    )[:limit_files]
+    # Collect minimal info per session: just first/last records and metadata
+    by_session: dict[str, list[dict[str, Any]]] = defaultdict(list)
+    for path in files:
+        try:
+            with path.open(encoding="utf-8") as handle:
+                for raw_line in handle:
+                    line = raw_line.strip()
+                    if not line:
+                        continue
+                    try:
+                        payload = json.loads(line)
+                    except json.JSONDecodeError:
+                        continue
+                    execution_id = str(payload.get("execution_id") or "").strip()
+                    if execution_id:
+                        # Store minimal data for summary generation
+                        minimal = {
+                            "timestamp": payload.get("timestamp", ""),
+                            "iteration": payload.get("iteration", 0),
+                            "stream_id": payload.get("stream_id", ""),
+                            "node_id": payload.get("node_id", ""),
+                            "token_counts": payload.get("token_counts", {}),
+                            "_log_file": str(path),
+                        }
+                        by_session[execution_id].append(minimal)
+        except OSError:
+            continue
+    # Filter out test sessions if needed
+    if not include_tests:
+        by_session = {
+            eid: recs
+            for eid, recs in by_session.items()
+            if not _is_test_session(eid, recs)
+        }
+    summaries: list[SessionSummary] = []
+    for execution_id, session_records in by_session.items():
+        session_records.sort(
+            key=lambda record: (
+                str(record.get("timestamp", "")),
+                record.get("iteration", 0),
+            )
+        )
+        first = session_records[0]
+        last = session_records[-1]
+        summaries.append(
+            SessionSummary(
+                execution_id=execution_id,
+                log_file=str(first.get("_log_file", "")),
+                start_timestamp=str(first.get("timestamp", "")),
+                end_timestamp=str(last.get("timestamp", "")),
+                turn_count=len(session_records),
+                streams=sorted(
+                    {str(r.get("stream_id", "")) for r in session_records if r.get("stream_id")}
+                ),
+                nodes=sorted(
+                    {str(r.get("node_id", "")) for r in session_records if r.get("node_id")}
+                ),
+                models=sorted(
+                    {
+                        str(r.get("token_counts", {}).get("model", ""))
+                        for r in session_records
+                        if isinstance(r.get("token_counts"), dict)
+                        and r.get("token_counts", {}).get("model")
+                    }
+                ),
+            )
+        )
+    summaries.sort(key=lambda summary: summary.start_timestamp, reverse=True)
+    return summaries
+
+
 def main() -> int:
     args = _parse_args()
-    records = _discover_records(args.logs_dir.expanduser(), args.limit_files)
-    summaries, sessions = _group_sessions(records, include_tests=args.include_tests)
+    logs_dir = args.logs_dir.expanduser()
+    # Only discover summaries, not full session data
+    summaries = _discover_session_summaries(
+        logs_dir, args.limit_files, args.include_tests
+    )
     initial_session_id = args.session or (summaries[0].execution_id if summaries else "")
-    if initial_session_id and initial_session_id not in sessions:
+    if initial_session_id and not any(
+        s.execution_id == initial_session_id for s in summaries
+    ):
         print(f"session not found: {initial_session_id}")
         return 1
@@ -945,7 +1087,7 @@ def main() -> int:
         print(args.output)
         return 0
 
-    _run_server(html_report, sessions, args.port, args.no_open)
+    _run_server(html_report, logs_dir, args.limit_files, args.port, args.no_open)
     return 0
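
For context on how the lazy endpoint behaves after this change: GET /api/session/<id> loads that session's records from the .jsonl logs on first request, serves later requests for the same id from session_cache, and answers unknown ids with a 404 whose body is an empty JSON array. A minimal client sketch against these routes follows; the port and session id are placeholders, not values from this commit:

import json
import urllib.error
import urllib.parse
import urllib.request

def fetch_session(port: int, session_id: str) -> list | None:
    # Quote the id so it survives the URL path; the server unquotes it.
    url = f"http://127.0.0.1:{port}/api/session/{urllib.parse.quote(session_id)}"
    try:
        with urllib.request.urlopen(url) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as err:
        if err.code == 404:
            return None  # session not present in the scanned log files
        raise

A second request for the same id is served from the in-process cache, so the log directory is only rescanned for session ids the server has not seen before.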