fix: bug causing queen message injection when resuming a session
This commit is contained in:
@@ -878,13 +878,56 @@ def _sort_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
)
|
||||
|
||||
|
||||
def _load_session_data(
|
||||
logs_dir: Path, session_id: str, limit_files: int
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Load records for a specific session on demand."""
|
||||
if not logs_dir.exists():
|
||||
return None
|
||||
|
||||
files = sorted(
|
||||
[path for path in logs_dir.iterdir() if path.is_file() and path.suffix == ".jsonl"],
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)[:limit_files]
|
||||
|
||||
records: list[dict[str, Any]] = []
|
||||
for path in files:
|
||||
try:
|
||||
with path.open(encoding="utf-8") as handle:
|
||||
for line_number, raw_line in enumerate(handle, start=1):
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
payload = {
|
||||
"timestamp": "",
|
||||
"execution_id": "",
|
||||
"assistant_text": "",
|
||||
"_parse_error": f"{path.name}:{line_number}",
|
||||
"_raw_line": line,
|
||||
}
|
||||
# Only include records for this session
|
||||
if str(payload.get("execution_id") or "").strip() == session_id:
|
||||
payload["_log_file"] = str(path)
|
||||
records.append(payload)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
return _sort_records(records) if records else None
|
||||
|
||||
|
||||
def _run_server(
|
||||
html: str,
|
||||
sessions: dict[str, list[dict[str, Any]]],
|
||||
logs_dir: Path,
|
||||
limit_files: int,
|
||||
port: int,
|
||||
no_open: bool,
|
||||
) -> None:
|
||||
html_bytes = html.encode("utf-8")
|
||||
session_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None:
|
||||
@@ -892,11 +935,17 @@ def _run_server(
|
||||
self._respond(200, "text/html; charset=utf-8", html_bytes)
|
||||
elif self.path.startswith("/api/session/"):
|
||||
sid = urllib.parse.unquote(self.path[len("/api/session/") :])
|
||||
records = sessions.get(sid)
|
||||
# Check cache first
|
||||
if sid in session_cache:
|
||||
records = session_cache[sid]
|
||||
else:
|
||||
records = _load_session_data(logs_dir, sid, limit_files)
|
||||
if records is not None:
|
||||
session_cache[sid] = records
|
||||
if records is None:
|
||||
self._respond(404, "application/json", b"[]")
|
||||
else:
|
||||
body = json.dumps(_sort_records(records), ensure_ascii=False).encode("utf-8")
|
||||
body = json.dumps(records, ensure_ascii=False).encode("utf-8")
|
||||
self._respond(200, "application/json", body)
|
||||
else:
|
||||
self.send_error(404)
|
||||
@@ -927,13 +976,106 @@ def _run_server(
|
||||
server.server_close()
|
||||
|
||||
|
||||
def _discover_session_summaries(
    logs_dir: Path, limit_files: int, include_tests: bool
) -> list[SessionSummary]:
    """Discover only session summaries without loading full record data.

    Reads the newest *limit_files* ``.jsonl`` files in *logs_dir* and keeps a
    minimal per-record projection, grouped by ``execution_id``.  Test sessions
    are dropped unless *include_tests* is set.  Raises ``FileNotFoundError``
    when *logs_dir* does not exist.
    """
    if not logs_dir.exists():
        raise FileNotFoundError(f"log directory not found: {logs_dir}")

    candidates = [
        entry
        for entry in logs_dir.iterdir()
        if entry.is_file() and entry.suffix == ".jsonl"
    ]
    candidates.sort(key=lambda entry: entry.stat().st_mtime, reverse=True)

    # Collect minimal info per session: just first/last records and metadata
    by_session: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for log_path in candidates[:limit_files]:
        try:
            with log_path.open(encoding="utf-8") as stream:
                for raw in stream:
                    text = raw.strip()
                    if not text:
                        continue
                    try:
                        entry = json.loads(text)
                    except json.JSONDecodeError:
                        # Malformed lines carry no session id; skip them here.
                        continue
                    execution_id = str(entry.get("execution_id") or "").strip()
                    if not execution_id:
                        continue
                    # Store minimal data for summary generation
                    by_session[execution_id].append(
                        {
                            "timestamp": entry.get("timestamp", ""),
                            "iteration": entry.get("iteration", 0),
                            "stream_id": entry.get("stream_id", ""),
                            "node_id": entry.get("node_id", ""),
                            "token_counts": entry.get("token_counts", {}),
                            "_log_file": str(log_path),
                        }
                    )
        except OSError:
            # Unreadable file: skip it rather than abort discovery.
            continue

    # Filter out test sessions if needed
    if not include_tests:
        by_session = {
            eid: recs
            for eid, recs in by_session.items()
            if not _is_test_session(eid, recs)
        }

    summaries: list[SessionSummary] = []
    for execution_id, session_records in by_session.items():
        # Chronological order within the session (timestamp, then iteration).
        session_records.sort(
            key=lambda record: (
                str(record.get("timestamp", "")),
                record.get("iteration", 0),
            )
        )
        first, last = session_records[0], session_records[-1]
        stream_ids = {
            str(r.get("stream_id", "")) for r in session_records if r.get("stream_id")
        }
        node_ids = {
            str(r.get("node_id", "")) for r in session_records if r.get("node_id")
        }
        model_names = {
            str(r.get("token_counts", {}).get("model", ""))
            for r in session_records
            if isinstance(r.get("token_counts"), dict)
            and r.get("token_counts", {}).get("model")
        }
        summaries.append(
            SessionSummary(
                execution_id=execution_id,
                log_file=str(first.get("_log_file", "")),
                start_timestamp=str(first.get("timestamp", "")),
                end_timestamp=str(last.get("timestamp", "")),
                turn_count=len(session_records),
                streams=sorted(stream_ids),
                nodes=sorted(node_ids),
                models=sorted(model_names),
            )
        )

    # Newest session first.
    summaries.sort(key=lambda summary: summary.start_timestamp, reverse=True)
    return summaries
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = _parse_args()
|
||||
records = _discover_records(args.logs_dir.expanduser(), args.limit_files)
|
||||
summaries, sessions = _group_sessions(records, include_tests=args.include_tests)
|
||||
logs_dir = args.logs_dir.expanduser()
|
||||
|
||||
# Only discover summaries, not full session data
|
||||
summaries = _discover_session_summaries(
|
||||
logs_dir, args.limit_files, args.include_tests
|
||||
)
|
||||
|
||||
initial_session_id = args.session or (summaries[0].execution_id if summaries else "")
|
||||
if initial_session_id and initial_session_id not in sessions:
|
||||
if initial_session_id and not any(
|
||||
s.execution_id == initial_session_id for s in summaries
|
||||
):
|
||||
print(f"session not found: {initial_session_id}")
|
||||
return 1
|
||||
|
||||
@@ -945,7 +1087,7 @@ def main() -> int:
|
||||
print(args.output)
|
||||
return 0
|
||||
|
||||
_run_server(html_report, sessions, args.port, args.no_open)
|
||||
_run_server(html_report, logs_dir, args.limit_files, args.port, args.no_open)
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user