fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)

* Restore rollback checkpoints with fresh ids

* Tighten rollback checkpoint tests and imports

* Update test_run_worker_rollback.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
KiteEater
2026-05-02 11:25:45 +08:00
committed by GitHub
parent 866d1ca409
commit 17447fccbe
4 changed files with 75 additions and 19 deletions
+1 -2
View File
@@ -18,6 +18,7 @@ import uuid
from typing import Any from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from langgraph.checkpoint.base import empty_checkpoint
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
@@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
# Write an empty checkpoint so state endpoints work immediately # Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try: try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = { ckpt_metadata = {
"step": -1, "step": -1,
"source": "input", "source": "input",
@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
from langgraph.checkpoint.base import empty_checkpoint
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
@@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
if checkpoint_to_restore.get("id") is None: if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id) logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return return
restore_marker = _new_checkpoint_marker()
checkpoint_to_restore = {
**checkpoint_to_restore,
"id": restore_marker["id"],
"ts": restore_marker["ts"],
}
metadata = pre_run_snapshot.get("metadata", {}) metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {} metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns") raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
@@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
) )
def _new_checkpoint_marker() -> dict[str, str]:
    """Return the ``id`` and ``ts`` of a newly created empty checkpoint.

    Only these two fields are kept; everything else on the empty checkpoint
    is discarded.
    """
    fresh = empty_checkpoint()
    return {key: fresh[key] for key in ("id", "ts")}
def _lg_mode_to_sse_event(mode: str) -> str: def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name. """Map LangGraph internal stream_mode name to SSE event name.
-1
View File
@@ -47,4 +47,3 @@ members = ["packages/harness"]
[tool.uv.sources] [tool.uv.sources]
deerflow-harness = { workspace = true } deerflow-harness = { workspace = true }
+61 -16
View File
@@ -3,6 +3,8 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, call from unittest.mock import AsyncMock, call
import pytest import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime.runs.manager import RunManager from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus from deerflow.runtime.runs.schemas import RunStatus
@@ -16,6 +18,14 @@ class FakeCheckpointer:
self.aput_writes = AsyncMock() self.aput_writes = AsyncMock()
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    """Build a checkpoint for tests with a single ``messages`` channel.

    Starts from ``empty_checkpoint()`` and overrides its ``id``,
    ``channel_values`` and ``channel_versions`` with the given values.
    """
    checkpoint = empty_checkpoint()
    checkpoint.update(
        {
            "id": checkpoint_id,
            "channel_values": {"messages": messages},
            "channel_versions": {"messages": version},
        }
    )
    return checkpoint
def test_build_runtime_context_includes_app_config_when_present(): def test_build_runtime_context_includes_app_config_when_present():
app_config = object() app_config = object()
@@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
) )
checkpointer.adelete_thread.assert_not_awaited() checkpointer.adelete_thread.assert_not_awaited()
checkpointer.aput.assert_awaited_once_with( checkpointer.aput.assert_awaited_once()
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
{ assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
"id": "ckpt-1", assert restored_checkpoint["id"] != "ckpt-1"
"channel_versions": {"messages": 3}, assert "channel_versions" in restored_checkpoint
"channel_values": {"messages": ["before"]}, assert "channel_values" in restored_checkpoint
}, assert restored_checkpoint["channel_versions"] == {"messages": 3}
{"source": "input"}, assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
{"messages": 3}, assert restored_metadata == {"source": "input"}
) assert new_versions == {"messages": 3}
assert checkpointer.aput_writes.await_args_list == [ assert checkpointer.aput_writes.await_args_list == [
call( call(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
@@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
] ]
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    """After rollback, the latest checkpoint of a real ``InMemorySaver`` holds the pre-run state."""
    saver = InMemorySaver()
    root_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

    # Seed the thread with a pre-run checkpoint ("0001") followed by a newer
    # one ("0002") that also carries a pending write.
    pre_run = _make_checkpoint("0001", ["before"], 1)
    pre_run_config = saver.put(root_config, pre_run, {"step": 1}, {"messages": 1})
    post_run = _make_checkpoint("0002", ["after"], 2)
    post_run_config = saver.put(pre_run_config, post_run, {"step": 2}, {"messages": 2})
    saver.put_writes(post_run_config, [("messages", "pending-after")], task_id="task-after")

    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": pre_run,
        "metadata": {"step": 1},
        "pending_writes": [("task-before", "messages", "pending-before")],
    }
    await _rollback_to_pre_run_checkpoint(
        checkpointer=saver,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    latest = saver.get_tuple(root_config)
    assert latest is not None

    # The restored checkpoint must be stored under an id that matches neither
    # of the two seeded checkpoints.
    latest_id = latest.config["configurable"]["checkpoint_id"]
    assert latest_id not in ("0001", "0002")

    # Its state and pending writes must match the pre-run snapshot, with
    # nothing carried over from the "0002" checkpoint.
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes
@pytest.mark.anyio @pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists(): async def test_rollback_deletes_thread_when_no_snapshot_exists():
checkpointer = FakeCheckpointer(put_result=None) checkpointer = FakeCheckpointer(put_result=None)
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
snapshot_capture_failed=False, snapshot_capture_failed=False,
) )
checkpointer.aput.assert_awaited_once_with( checkpointer.aput.assert_awaited_once()
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
{"id": "ckpt-1", "channel_versions": {}}, assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
{}, assert restored_checkpoint["id"] != "ckpt-1"
{}, assert restored_checkpoint["channel_versions"] == {}
) assert restored_metadata == {}
assert new_versions == {}
@pytest.mark.anyio @pytest.mark.anyio