removed dead tests, updated some drifting behavior
@@ -421,6 +421,9 @@ def cmd_run(args: argparse.Namespace) -> int:
            model=args.model,
            enable_tui=False,
        )
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -524,10 +527,14 @@ def cmd_run(args: argparse.Namespace) -> int:

def cmd_info(args: argparse.Namespace) -> int:
    """Show agent information."""
    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    try:
        runner = AgentRunner.load(args.agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -587,10 +594,14 @@ def cmd_info(args: argparse.Namespace) -> int:

def cmd_validate(args: argparse.Namespace) -> int:
    """Validate an exported agent."""
    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    try:
        runner = AgentRunner.load(args.agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
@@ -907,6 +918,7 @@ def cmd_shell(args: argparse.Namespace) -> int:
    """Start an interactive agent session."""
    import logging

    from framework.credentials.models import CredentialError
    from framework.runner import AgentRunner

    # Configure logging to show runtime visibility
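The hunk above stops at the logging setup for cmd_shell, so the actual configuration is not visible here. As a rough, hypothetical sketch (the level, format string, and logger names used by the real CLI are assumptions, not part of this diff), "runtime visibility" usually amounts to something like:

```python
import logging

# Hypothetical sketch only: the real cmd_shell may use a different level/format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("framework").setLevel(logging.INFO)  # assumed logger name
```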
@@ -931,6 +943,9 @@ def cmd_shell(args: argparse.Namespace) -> int:

    try:
        runner = AgentRunner.load(agent_path)
    except CredentialError as e:
        print(f"\n{e}", file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

@@ -346,8 +346,7 @@ class AgentRunner:
        for node in self.graph.nodes:
            if node.tools:
                required_tools.update(node.tools)
        if not required_tools:
            return
        node_types: set[str] = {node.node_type for node in self.graph.nodes}

        try:
            from aden_tools.credentials import CREDENTIAL_SPECS
@@ -374,14 +373,19 @@ class AgentRunner:
        storage = CompositeStorage(primary=storages[0], fallbacks=storages[1:])
        store = CredentialStore(storage=storage)

        # Build tool→credential mapping and check
        # Build reverse mappings
        tool_to_cred: dict[str, str] = {}
        node_type_to_cred: dict[str, str] = {}
        for cred_name, spec in CREDENTIAL_SPECS.items():
            for tool_name in spec.tools:
                tool_to_cred[tool_name] = cred_name
            for nt in spec.node_types:
                node_type_to_cred[nt] = cred_name

        missing: list[str] = []
        checked: set[str] = set()

        # Check tool credentials
        for tool_name in sorted(required_tools):
            cred_name = tool_to_cred.get(tool_name)
            if cred_name is None or cred_name in checked:
@@ -396,6 +400,21 @@ class AgentRunner:
                    entry += f"\n Get it at: {spec.help_url}"
                missing.append(entry)

        # Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
        for nt in sorted(node_types):
            cred_name = node_type_to_cred.get(nt)
            if cred_name is None or cred_name in checked:
                continue
            checked.add(cred_name)
            spec = CREDENTIAL_SPECS[cred_name]
            cred_id = spec.credential_id or cred_name
            if spec.required and not store.is_available(cred_id):
                affected_types = sorted(t for t in node_types if t in spec.node_types)
                entry = f" {spec.env_var} for {', '.join(affected_types)} nodes"
                if spec.help_url:
                    entry += f"\n Get it at: {spec.help_url}"
                missing.append(entry)

        if missing:
            from framework.credentials.models import CredentialError

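The credential check added above follows one pattern: invert CREDENTIAL_SPECS into tool→credential and node-type→credential lookups, collect a human-readable entry for each missing credential (covering both tools and node types such as ANTHROPIC_API_KEY for LLM nodes), and raise a single CredentialError at the end. A self-contained sketch of that shape, using stand-in Spec objects and a hypothetical SEARCH_API_KEY rather than the real aden_tools/framework types:

```python
from dataclasses import dataclass, field


@dataclass
class Spec:  # stand-in for the real CREDENTIAL_SPECS entries, not the aden_tools type
    env_var: str
    tools: list[str] = field(default_factory=list)
    node_types: list[str] = field(default_factory=list)
    required: bool = True


SPECS = {
    "anthropic": Spec("ANTHROPIC_API_KEY", node_types=["llm"]),
    "search": Spec("SEARCH_API_KEY", tools=["web_search"]),  # hypothetical example
}


def missing_credentials(
    required_tools: set[str], node_types: set[str], available: set[str]
) -> list[str]:
    # Invert the specs into reverse lookups, mirroring tool_to_cred / node_type_to_cred above.
    tool_to_cred = {t: name for name, spec in SPECS.items() for t in spec.tools}
    type_to_cred = {nt: name for name, spec in SPECS.items() for nt in spec.node_types}

    missing: list[str] = []
    checked: set[str] = set()
    candidates = [(t, tool_to_cred.get(t)) for t in sorted(required_tools)]
    candidates += [(nt, type_to_cred.get(nt)) for nt in sorted(node_types)]
    for item, cred in candidates:
        if cred is None or cred in checked:
            continue
        checked.add(cred)
        spec = SPECS[cred]
        if spec.required and cred not in available:
            missing.append(f"  {spec.env_var} (needed by {item})")
    return missing


# missing_credentials({"web_search"}, {"llm"}, available=set())
# -> ["  SEARCH_API_KEY (needed by web_search)", "  ANTHROPIC_API_KEY (needed by llm)"]
```

The real implementation above additionally honors spec.credential_id, spec.help_url, and the set of affected node types when building each entry.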
@@ -85,7 +85,7 @@ class ChatRepl(Vertical):
        self._pending_ask_question: str = ""
        self._resume_session = resume_session
        self._resume_checkpoint = resume_checkpoint
        self._session_index: list[str] = []  # Ordered session IDs from last /sessions or /resume listing
        self._session_index: list[str] = []  # IDs from last listing

        # Dedicated event loop for agent execution.
        # Keeps blocking runtime code (LLM calls, MCP tools) off
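The (truncated) comment above describes the TUI's concurrency model: agent execution gets its own event loop so that blocking runtime code (LLM calls, MCP tools) stays off the loop driving the Textual UI. A minimal sketch of that pattern, assuming nothing about ChatRepl's actual attributes beyond what the comment states:

```python
import asyncio
import concurrent.futures
import threading


class AgentLoop:
    """Sketch: run agent coroutines on a dedicated event loop in a background thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def submit(self, coro) -> concurrent.futures.Future:
        # Hand work to the agent loop without blocking the caller (e.g. the UI thread).
        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()
```

Whether ChatRepl wraps this in a helper like the one sketched here or inlines it is not shown in this hunk.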
@@ -173,12 +173,8 @@ class ChatRepl(Vertical):
            if 0 <= idx < len(self._session_index):
                session_id = self._session_index[idx]
            else:
                self._write_history(
                    f"[bold red]Error:[/bold red] No session at index {arg}"
                )
                self._write_history(
                    " Use [bold]/resume[/bold] to see available sessions"
                )
                self._write_history(f"[bold red]Error:[/bold red] No session at index {arg}")
                self._write_history(" Use [bold]/resume[/bold] to see available sessions")
                return
        else:
            session_id = arg
@@ -328,9 +324,7 @@ class ChatRepl(Vertical):
                status_colored = f"[dim]{status}[/dim]"

                # Session line with index and label
                self._write_history(
                    f" [bold]{index}.[/bold] {label} {status_colored}"
                )
                self._write_history(f" [bold]{index}.[/bold] {label} {status_colored}")
                self._write_history(f" [dim]{session_id}[/dim]")
                self._write_history("")  # Blank line

@@ -338,9 +332,7 @@ class ChatRepl(Vertical):
                self._write_history(f" [dim red]Error reading: {e}[/dim red]")

        if self._session_index:
            self._write_history(
                "[dim]Use [bold]/resume <number>[/bold] to resume a session[/dim]"
            )
            self._write_history("[dim]Use [bold]/resume <number>[/bold] to resume a session[/dim]")

    async def _show_session_details(self, storage_path: Path, session_id: str) -> None:
        """Show detailed information about a specific session."""
@@ -740,9 +732,7 @@ class ChatRepl(Vertical):

            self._write_history(f" [bold]{i}.[/bold] {label} {status_colored}")

            self._write_history(
                "\n Type [bold]/resume <number>[/bold] to continue a session"
            )
            self._write_history("\n Type [bold]/resume <number>[/bold] to continue a session")
            self._write_history(" Or just type your input to start a new session\n")

        except Exception:

@@ -1,342 +0,0 @@
"""Tests for the BuilderQuery interface - how Builder analyzes agent runs.

DEPRECATED: These tests rely on the deprecated FileStorage backend.
BuilderQuery and Runtime both use FileStorage which is deprecated.
New code should use unified session storage instead.
"""

from pathlib import Path

import pytest

from framework import BuilderQuery, Runtime
from framework.schemas.run import RunStatus

# Mark all tests in this module as skipped - they rely on deprecated FileStorage
pytestmark = pytest.mark.skip(reason="Tests rely on deprecated FileStorage backend")


def create_successful_run(runtime: Runtime, goal_id: str = "test_goal") -> str:
    """Helper to create a successful run with decisions."""
    run_id = runtime.start_run(goal_id, f"Test goal: {goal_id}")

    runtime.set_node("search-node")
    d1 = runtime.decide(
        intent="Search for data",
        options=[
            {"id": "web", "description": "Web search", "pros": ["Fresh data"]},
            {"id": "cache", "description": "Use cache", "pros": ["Fast"]},
        ],
        chosen="web",
        reasoning="Need fresh data",
    )
    runtime.record_outcome(d1, success=True, result={"items": 3}, tokens_used=50)

    runtime.set_node("process-node")
    d2 = runtime.decide(
        intent="Process results",
        options=[{"id": "filter", "description": "Filter and transform"}],
        chosen="filter",
        reasoning="Standard processing",
    )
    runtime.record_outcome(d2, success=True, result={"processed": 3}, tokens_used=30)

    runtime.end_run(success=True, narrative="Successfully processed data")
    return run_id


def create_failed_run(runtime: Runtime, goal_id: str = "test_goal") -> str:
    """Helper to create a failed run."""
    run_id = runtime.start_run(goal_id, f"Test goal: {goal_id}")

    runtime.set_node("search-node")
    d1 = runtime.decide(
        intent="Search for data",
        options=[{"id": "web", "description": "Web search"}],
        chosen="web",
        reasoning="Need data",
    )
    runtime.record_outcome(d1, success=True, result={"items": 0})

    runtime.set_node("process-node")
    d2 = runtime.decide(
        intent="Process results",
        options=[{"id": "process", "description": "Process data"}],
        chosen="process",
        reasoning="Continue pipeline",
    )
    runtime.record_outcome(d2, success=False, error="No data to process")

    runtime.report_problem(
        severity="critical",
        description="Processing failed due to empty input",
        decision_id=d2,
        suggested_fix="Add empty input handling",
    )

    runtime.end_run(success=False, narrative="Failed to process - no data")
    return run_id


class TestBuilderQueryBasics:
    """Test basic query operations."""

    def test_get_run_summary(self, tmp_path: Path):
        """Test getting a run summary."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        summary = query.get_run_summary(run_id)

        assert summary is not None
        assert summary.run_id == run_id
        assert summary.status == RunStatus.COMPLETED
        assert summary.decision_count == 2
        assert summary.success_rate == 1.0

    def test_get_full_run(self, tmp_path: Path):
        """Test getting the full run details."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        run = query.get_full_run(run_id)

        assert run is not None
        assert run.id == run_id
        assert len(run.decisions) == 2
        assert run.decisions[0].node_id == "search-node"
        assert run.decisions[1].node_id == "process-node"

    def test_list_runs_for_goal(self, tmp_path: Path):
        """Test listing all runs for a goal."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime, "goal_a")
        create_successful_run(runtime, "goal_a")
        create_successful_run(runtime, "goal_b")

        query = BuilderQuery(tmp_path)
        summaries = query.list_runs_for_goal("goal_a")

        assert len(summaries) == 2
        for s in summaries:
            assert s.goal_id == "goal_a"

    def test_get_recent_failures(self, tmp_path: Path):
        """Test getting recent failed runs."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime)
        create_failed_run(runtime)
        create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        failures = query.get_recent_failures()

        assert len(failures) == 2
        for f in failures:
            assert f.status == RunStatus.FAILED


class TestFailureAnalysis:
    """Test failure analysis capabilities."""

    def test_analyze_failure(self, tmp_path: Path):
        """Test analyzing why a run failed."""
        runtime = Runtime(tmp_path)
        run_id = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert analysis is not None
        assert analysis.run_id == run_id
        assert "No data to process" in analysis.root_cause
        assert len(analysis.decision_chain) >= 2
        assert len(analysis.problems) == 1
        assert "critical" in analysis.problems[0].lower()

    def test_analyze_failure_returns_none_for_success(self, tmp_path: Path):
        """analyze_failure returns None for successful runs."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert analysis is None

    def test_failure_analysis_has_suggestions(self, tmp_path: Path):
        """Failure analysis should include suggestions."""
        runtime = Runtime(tmp_path)
        run_id = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        analysis = query.analyze_failure(run_id)

        assert len(analysis.suggestions) > 0
        # Should include the suggested fix from the problem
        assert any("empty input" in s.lower() for s in analysis.suggestions)

    def test_get_decision_trace(self, tmp_path: Path):
        """Test getting a readable decision trace."""
        runtime = Runtime(tmp_path)
        run_id = create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        trace = query.get_decision_trace(run_id)

        assert len(trace) == 2
        assert "search-node" in trace[0]
        assert "process-node" in trace[1]


class TestPatternAnalysis:
    """Test pattern detection across runs."""

    def test_find_patterns_basic(self, tmp_path: Path):
        """Test basic pattern finding."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime, "goal_x")
        create_successful_run(runtime, "goal_x")
        create_failed_run(runtime, "goal_x")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("goal_x")

        assert patterns is not None
        assert patterns.goal_id == "goal_x"
        assert patterns.run_count == 3
        assert 0 < patterns.success_rate < 1  # 2/3 success

    def test_find_patterns_common_failures(self, tmp_path: Path):
        """Test finding common failures."""
        runtime = Runtime(tmp_path)
        # Create multiple runs with the same failure
        for _ in range(3):
            create_failed_run(runtime, "failing_goal")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("failing_goal")

        assert len(patterns.common_failures) > 0
        # "No data to process" should be a common failure
        failure_messages = [f[0] for f in patterns.common_failures]
        assert any("No data to process" in msg for msg in failure_messages)

    def test_find_patterns_problematic_nodes(self, tmp_path: Path):
        """Test finding problematic nodes."""
        runtime = Runtime(tmp_path)
        # Create runs where process-node always fails
        for _ in range(3):
            create_failed_run(runtime, "node_test")

        query = BuilderQuery(tmp_path)
        patterns = query.find_patterns("node_test")

        # process-node should be flagged as problematic
        problematic_node_ids = [n[0] for n in patterns.problematic_nodes]
        assert "process-node" in problematic_node_ids

    def test_compare_runs(self, tmp_path: Path):
        """Test comparing two runs."""
        runtime = Runtime(tmp_path)
        run1 = create_successful_run(runtime)
        run2 = create_failed_run(runtime)

        query = BuilderQuery(tmp_path)
        comparison = query.compare_runs(run1, run2)

        assert comparison["run_1"]["status"] == "completed"
        assert comparison["run_2"]["status"] == "failed"
        assert len(comparison["differences"]) > 0


class TestImprovementSuggestions:
    """Test improvement suggestion generation."""

    def test_suggest_improvements(self, tmp_path: Path):
        """Test generating improvement suggestions."""
        runtime = Runtime(tmp_path)
        # Create runs with failures to trigger suggestions
        for _ in range(3):
            create_failed_run(runtime, "improve_goal")

        query = BuilderQuery(tmp_path)
        suggestions = query.suggest_improvements("improve_goal")

        assert len(suggestions) > 0
        # Should suggest improving the problematic node
        node_suggestions = [s for s in suggestions if s["type"] == "node_improvement"]
        assert len(node_suggestions) > 0

    def test_suggest_improvements_for_low_success_rate(self, tmp_path: Path):
        """Should suggest architecture review for low success rate."""
        runtime = Runtime(tmp_path)
        # 4 failures, 1 success = 20% success rate
        for _ in range(4):
            create_failed_run(runtime, "low_success")
        create_successful_run(runtime, "low_success")

        query = BuilderQuery(tmp_path)
        suggestions = query.suggest_improvements("low_success")

        arch_suggestions = [s for s in suggestions if s["type"] == "architecture"]
        assert len(arch_suggestions) > 0
        assert arch_suggestions[0]["priority"] == "high"

    def test_get_node_performance(self, tmp_path: Path):
        """Test getting performance metrics for a node."""
        runtime = Runtime(tmp_path)
        create_successful_run(runtime)
        create_successful_run(runtime)

        query = BuilderQuery(tmp_path)
        perf = query.get_node_performance("search-node")

        assert perf["node_id"] == "search-node"
        assert perf["total_decisions"] == 2
        assert perf["success_rate"] == 1.0
        assert perf["total_tokens"] == 100  # 50 tokens per run


class TestBuilderWorkflow:
    """Test complete Builder workflows."""

    def test_builder_investigation_workflow(self, tmp_path: Path):
        """Test a complete investigation workflow as Builder would use it."""
        runtime = Runtime(tmp_path)

        # Set up scenario: some successes, some failures
        for _ in range(2):
            create_successful_run(runtime, "customer_goal")
        for _ in range(2):
            create_failed_run(runtime, "customer_goal")

        query = BuilderQuery(tmp_path)

        # Step 1: Get overview of the goal
        summaries = query.list_runs_for_goal("customer_goal")
        assert len(summaries) == 4

        # Step 2: Find patterns
        patterns = query.find_patterns("customer_goal")
        assert patterns.success_rate == 0.5  # 2/4

        # Step 3: Get recent failures
        failures = query.get_recent_failures()
        assert len(failures) == 2

        # Step 4: Analyze a specific failure
        failure_id = failures[0].run_id
        analysis = query.analyze_failure(failure_id)
        assert analysis is not None
        assert len(analysis.suggestions) > 0

        # Step 5: Generate improvement suggestions
        suggestions = query.suggest_improvements("customer_goal")
        assert len(suggestions) > 0

        # Step 6: Check node performance
        perf = query.get_node_performance("process-node")
        assert perf["success_rate"] < 1.0  # process-node fails in failed runs

@@ -1,185 +0,0 @@
"""Tests for ConcurrentStorage race condition and cache invalidation fixes."""

import asyncio
from pathlib import Path

import pytest

from framework.schemas.run import Run, RunMetrics, RunStatus
from framework.storage.concurrent import ConcurrentStorage


def create_test_run(
    run_id: str, goal_id: str = "test-goal", status: RunStatus = RunStatus.RUNNING
) -> Run:
    """Create a minimal test Run object."""
    return Run(
        id=run_id,
        goal_id=goal_id,
        status=status,
        narrative="Test run",
        metrics=RunMetrics(
            nodes_executed=[],
        ),
        decisions=[],
        problems=[],
    )


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_cache_invalidation_on_save(tmp_path: Path):
    """Test that summary cache is invalidated when a run is saved.

    This tests the fix for the cache invalidation bug where load_summary()
    would return stale data after a run was updated.
    """
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-1"

        # Create and save initial run
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to populate the cache
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.RUNNING

        # Update run with new status
        run.status = RunStatus.COMPLETED
        await storage.save_run(run, immediate=True)

        # Load summary again - should get fresh data, not cached stale data
        summary = await storage.load_summary(run_id)
        assert summary is not None
        assert summary.status == RunStatus.COMPLETED, (
            "Summary cache should be invalidated on save - got stale data"
        )
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_batched_write_cache_consistency(tmp_path: Path):
    """Test that cache is only updated after successful batched write.

    This tests the fix for the race condition where cache was updated
    before the batched write completed.
    """
    storage = ConcurrentStorage(tmp_path, batch_interval=0.05)
    await storage.start()

    try:
        run_id = "test-run-2"

        # Save via batching (immediate=False)
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=False)

        # Before batch flush, cache should NOT contain the run
        # (This is the fix - previously cache was updated immediately)
        cache_key = f"run:{run_id}"
        assert cache_key not in storage._cache, (
            "Cache should not be updated before batch is flushed"
        )

        # Wait for batch to flush (poll instead of fixed sleep for CI reliability)
        for _ in range(500):  # 500 * 0.01s = 5s max
            if cache_key in storage._cache:
                break
            await asyncio.sleep(0.01)

        # After batch flush, cache should contain the run
        assert cache_key in storage._cache, "Cache should be updated after batch flush"

        # Verify data on disk matches cache
        loaded_run = await storage.load_run(run_id, use_cache=False)
        assert loaded_run is not None
        assert loaded_run.id == run_id
        assert loaded_run.status == RunStatus.RUNNING
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_immediate_write_updates_cache(tmp_path: Path):
    """Test that immediate writes still update cache correctly."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-3"

        # Save with immediate=True
        run = create_test_run(run_id, status=RunStatus.COMPLETED)
        await storage.save_run(run, immediate=True)

        # Cache should be updated immediately for immediate writes
        cache_key = f"run:{run_id}"
        assert cache_key in storage._cache, "Cache should be updated after immediate write"

        # Verify cached value is correct
        cached_run = storage._cache[cache_key].value
        assert cached_run.id == run_id
        assert cached_run.status == RunStatus.COMPLETED
    finally:
        await storage.stop()


@pytest.mark.skip(
    reason="FileStorage.save_run() is deprecated and now a no-op. "
    "ConcurrentStorage wraps FileStorage, so these tests no longer work. "
    "New sessions use unified storage at sessions/{session_id}/state.json"
)
@pytest.mark.asyncio
async def test_summary_cache_invalidated_on_multiple_saves(tmp_path: Path):
    """Test that summary cache is invalidated on each save, not just the first."""
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    try:
        run_id = "test-run-4"

        # First save
        run = create_test_run(run_id, status=RunStatus.RUNNING)
        await storage.save_run(run, immediate=True)

        # Load summary to cache it
        summary1 = await storage.load_summary(run_id)
        assert summary1.status == RunStatus.RUNNING

        # Second save with new status
        run.status = RunStatus.RUNNING
        await storage.save_run(run, immediate=True)

        # Load summary - should be fresh
        summary2 = await storage.load_summary(run_id)
        assert summary2.status == RunStatus.RUNNING

        # Third save with final status
        run.status = RunStatus.COMPLETED
        await storage.save_run(run, immediate=True)

        # Load summary - should be fresh again
        summary3 = await storage.load_summary(run_id)
        assert summary3.status == RunStatus.COMPLETED
    finally:
        await storage.stop()
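
The skip reasons above all point at the replacement layout: session state now lives at sessions/{session_id}/state.json. A hypothetical sketch of what a test against that layout could look like (the state.json field names and helper below are assumptions for illustration, not the framework's documented schema or API):

```python
import json
from pathlib import Path


def read_session_state(storage_root: Path, session_id: str) -> dict:
    """Load one session's unified state file (sketch only, not a framework API)."""
    state_file = storage_root / "sessions" / session_id / "state.json"
    return json.loads(state_file.read_text())


def test_session_state_roundtrip(tmp_path: Path) -> None:
    # Write a minimal state file at the documented location ...
    session_dir = tmp_path / "sessions" / "demo-session"
    session_dir.mkdir(parents=True)
    (session_dir / "state.json").write_text(json.dumps({"status": "completed"}))

    # ... and assert it reads back from sessions/{session_id}/state.json.
    assert read_session_state(tmp_path, "demo-session")["status"] == "completed"
```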