fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)
* Restore rollback checkpoints with fresh ids
* Tighten rollback checkpoint tests and imports
* Update test_run_worker_rollback.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import uuid
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
@@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
# Write an empty checkpoint so state endpoints work immediately
|
# Write an empty checkpoint so state endpoints work immediately
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
try:
|
try:
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
|
||||||
|
|
||||||
ckpt_metadata = {
|
ckpt_metadata = {
|
||||||
"step": -1,
|
"step": -1,
|
||||||
"source": "input",
|
"source": "input",
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
|
|
||||||
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
@@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
|
|||||||
if checkpoint_to_restore.get("id") is None:
|
if checkpoint_to_restore.get("id") is None:
|
||||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||||
return
|
return
|
||||||
|
restore_marker = _new_checkpoint_marker()
|
||||||
|
checkpoint_to_restore = {
|
||||||
|
**checkpoint_to_restore,
|
||||||
|
"id": restore_marker["id"],
|
||||||
|
"ts": restore_marker["ts"],
|
||||||
|
}
|
||||||
metadata = pre_run_snapshot.get("metadata", {})
|
metadata = pre_run_snapshot.get("metadata", {})
|
||||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||||
@@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _new_checkpoint_marker() -> dict[str, str]:
    """Return the ``id`` and ``ts`` of a newly created empty checkpoint.

    Only those two fields are kept; the rest of the empty checkpoint is
    discarded. The rollback path uses the pair to re-stamp the checkpoint
    it restores.
    """
    fresh = empty_checkpoint()
    return {key: fresh[key] for key in ("id", "ts")}
|
||||||
|
|
||||||
|
|
||||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||||
|
|
||||||
|
|||||||
@@ -47,4 +47,3 @@ members = ["packages/harness"]
|
|||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
deerflow-harness = { workspace = true }
|
deerflow-harness = { workspace = true }
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import AsyncMock, call
|
from unittest.mock import AsyncMock, call
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from deerflow.runtime.runs.manager import RunManager
|
from deerflow.runtime.runs.manager import RunManager
|
||||||
from deerflow.runtime.runs.schemas import RunStatus
|
from deerflow.runtime.runs.schemas import RunStatus
|
||||||
@@ -16,6 +18,14 @@ class FakeCheckpointer:
|
|||||||
self.aput_writes = AsyncMock()
|
self.aput_writes = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    """Build a test checkpoint with a fixed id and a single ``messages`` channel.

    Starts from ``empty_checkpoint()`` and overrides the id, the channel
    values and the channel versions; every other field keeps the value
    produced by ``empty_checkpoint()``.
    """
    ckpt = empty_checkpoint()
    ckpt.update(
        id=checkpoint_id,
        channel_values={"messages": messages},
        channel_versions={"messages": version},
    )
    return ckpt
|
||||||
|
|
||||||
|
|
||||||
def test_build_runtime_context_includes_app_config_when_present():
|
def test_build_runtime_context_includes_app_config_when_present():
|
||||||
app_config = object()
|
app_config = object()
|
||||||
|
|
||||||
@@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
|||||||
)
|
)
|
||||||
|
|
||||||
checkpointer.adelete_thread.assert_not_awaited()
|
checkpointer.adelete_thread.assert_not_awaited()
|
||||||
checkpointer.aput.assert_awaited_once_with(
|
checkpointer.aput.assert_awaited_once()
|
||||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||||
{
|
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||||
"id": "ckpt-1",
|
assert restored_checkpoint["id"] != "ckpt-1"
|
||||||
"channel_versions": {"messages": 3},
|
assert "channel_versions" in restored_checkpoint
|
||||||
"channel_values": {"messages": ["before"]},
|
assert "channel_values" in restored_checkpoint
|
||||||
},
|
assert restored_checkpoint["channel_versions"] == {"messages": 3}
|
||||||
{"source": "input"},
|
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
|
||||||
{"messages": 3},
|
assert restored_metadata == {"source": "input"}
|
||||||
)
|
assert new_versions == {"messages": 3}
|
||||||
assert checkpointer.aput_writes.await_args_list == [
|
assert checkpointer.aput_writes.await_args_list == [
|
||||||
call(
|
call(
|
||||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||||
@@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    """After rollback, the latest checkpoint in a real ``InMemorySaver`` holds the pre-run state."""
    saver = InMemorySaver()
    root_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

    # Seed the thread: pre-run checkpoint "0001", then a newer "0002" that
    # also carries a pending write which must not survive the rollback.
    pre_run = _make_checkpoint("0001", ["before"], 1)
    pre_run_config = saver.put(root_config, pre_run, {"step": 1}, {"messages": 1})
    post_run = _make_checkpoint("0002", ["after"], 2)
    post_run_config = saver.put(pre_run_config, post_run, {"step": 2}, {"messages": 2})
    saver.put_writes(post_run_config, [("messages", "pending-after")], task_id="task-after")

    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": pre_run,
        "metadata": {"step": 1},
        "pending_writes": [("task-before", "messages", "pending-before")],
    }
    await _rollback_to_pre_run_checkpoint(
        checkpointer=saver,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    latest = saver.get_tuple(root_config)

    assert latest is not None
    # The restored checkpoint must be stored under an id distinct from both seeded ones.
    latest_id = latest.config["configurable"]["checkpoint_id"]
    assert latest_id != "0001"
    assert latest_id != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||||
checkpointer = FakeCheckpointer(put_result=None)
|
checkpointer = FakeCheckpointer(put_result=None)
|
||||||
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
|||||||
snapshot_capture_failed=False,
|
snapshot_capture_failed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
checkpointer.aput.assert_awaited_once_with(
|
checkpointer.aput.assert_awaited_once()
|
||||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||||
{"id": "ckpt-1", "channel_versions": {}},
|
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||||
{},
|
assert restored_checkpoint["id"] != "ckpt-1"
|
||||||
{},
|
assert restored_checkpoint["channel_versions"] == {}
|
||||||
)
|
assert restored_metadata == {}
|
||||||
|
assert new_versions == {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
Reference in New Issue
Block a user