Merge pull request #6385 from Waryjustice/fix/google-sheets-credentials-orphan
fix: make state.json progress writes atomic in GraphExecutor
This commit is contained in:
@@ -32,6 +32,7 @@ from framework.observability import set_trace_context
|
||||
from framework.runtime.core import Runtime
|
||||
from framework.schemas.checkpoint import Checkpoint
|
||||
from framework.storage.checkpoint_store import CheckpointStore
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -226,11 +227,11 @@ class GraphExecutor:
|
||||
"""
|
||||
if not self._storage_path:
|
||||
return
|
||||
state_path = self._storage_path / "state.json"
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime
|
||||
|
||||
state_path = self._storage_path / "state.json"
|
||||
if state_path.exists():
|
||||
state_data = _json.loads(state_path.read_text(encoding="utf-8"))
|
||||
else:
|
||||
@@ -253,9 +254,14 @@ class GraphExecutor:
|
||||
state_data["memory"] = memory_snapshot
|
||||
state_data["memory_keys"] = list(memory_snapshot.keys())
|
||||
|
||||
state_path.write_text(_json.dumps(state_data, indent=2), encoding="utf-8")
|
||||
with atomic_write(state_path, encoding="utf-8") as f:
|
||||
_json.dump(state_data, f, indent=2)
|
||||
except Exception:
|
||||
pass # Best-effort — never block execution
|
||||
logger.warning(
|
||||
"Failed to persist progress state to %s",
|
||||
state_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _validate_tools(self, graph: GraphSpec) -> list[str]:
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,16 @@ Tests for core GraphExecutor execution paths.
|
||||
Focused on minimal success and failure scenarios.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.graph.edge import GraphSpec
|
||||
from framework.graph.executor import GraphExecutor
|
||||
from framework.graph.goal import Goal
|
||||
from framework.graph.node import NodeResult, NodeSpec
|
||||
from framework.utils.io import atomic_write
|
||||
|
||||
|
||||
# ---- Dummy runtime (no real logging) ----
|
||||
@@ -25,6 +29,14 @@ class DummyRuntime:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMemory:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def read_all(self):
|
||||
return self._data
|
||||
|
||||
|
||||
# ---- Fake node that always succeeds ----
|
||||
class SuccessNode:
|
||||
def validate_input(self, ctx):
|
||||
@@ -245,3 +257,61 @@ async def test_executor_no_events_without_event_bus():
|
||||
result = await executor.execute(graph=graph, goal=goal)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
def test_write_progress_uses_atomic_write_and_updates_state(tmp_path, monkeypatch):
|
||||
runtime = DummyRuntime()
|
||||
executor = GraphExecutor(runtime=runtime, storage_path=tmp_path)
|
||||
state_path = tmp_path / "state.json"
|
||||
state_path.write_text(json.dumps({"entry_point": "primary"}), encoding="utf-8")
|
||||
memory = DummyMemory({"foo": "bar"})
|
||||
|
||||
called = {}
|
||||
|
||||
def recording_atomic_write(path, *args, **kwargs):
|
||||
called["path"] = path
|
||||
return atomic_write(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("framework.graph.executor.atomic_write", recording_atomic_write)
|
||||
|
||||
executor._write_progress(
|
||||
current_node="node-b",
|
||||
path=["node-a", "node-b"],
|
||||
memory=memory,
|
||||
node_visit_counts={"node-a": 1, "node-b": 1},
|
||||
)
|
||||
|
||||
state = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
assert called["path"] == state_path
|
||||
assert state["entry_point"] == "primary"
|
||||
assert state["progress"]["current_node"] == "node-b"
|
||||
assert state["progress"]["path"] == ["node-a", "node-b"]
|
||||
assert state["progress"]["node_visit_counts"] == {"node-a": 1, "node-b": 1}
|
||||
assert state["progress"]["steps_executed"] == 2
|
||||
assert state["memory"] == {"foo": "bar"}
|
||||
assert state["memory_keys"] == ["foo"]
|
||||
assert "updated_at" in state["timestamps"]
|
||||
|
||||
|
||||
def test_write_progress_logs_warning_on_atomic_write_failure(tmp_path, monkeypatch, caplog):
|
||||
runtime = DummyRuntime()
|
||||
executor = GraphExecutor(runtime=runtime, storage_path=tmp_path)
|
||||
state_path = tmp_path / "state.json"
|
||||
state_path.write_text(json.dumps({"entry_point": "primary"}), encoding="utf-8")
|
||||
memory = DummyMemory({"foo": "bar"})
|
||||
|
||||
def failing_atomic_write(*args, **kwargs):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("framework.graph.executor.atomic_write", failing_atomic_write)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
executor._write_progress(
|
||||
current_node="node-b",
|
||||
path=["node-a", "node-b"],
|
||||
memory=memory,
|
||||
node_visit_counts={"node-a": 1, "node-b": 1},
|
||||
)
|
||||
|
||||
assert "Failed to persist progress state to" in caplog.text
|
||||
assert str(state_path) in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user