Merge pull request #1353 from Tahir-yamin/fix/concurrent-storage-file-locks-leak
fix(memory): patch ConcurrentStorage leak (WeakValueDictionary)
This commit is contained in:
@@ -10,10 +10,11 @@ Wraps FileStorage with:
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from framework.schemas.run import Run, RunStatus, RunSummary
|
||||
from framework.storage.backend import FileStorage
|
||||
@@ -61,6 +62,7 @@ class ConcurrentStorage:
|
||||
cache_ttl: float = 60.0,
|
||||
batch_interval: float = 0.1,
|
||||
max_batch_size: int = 100,
|
||||
max_locks: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize concurrent storage.
|
||||
@@ -70,6 +72,7 @@ class ConcurrentStorage:
|
||||
cache_ttl: Cache time-to-live in seconds
|
||||
batch_interval: Interval between batch flushes
|
||||
max_batch_size: Maximum items before forcing flush
|
||||
max_locks: Maximum number of active file locks to track strongly
|
||||
"""
|
||||
self.base_path = Path(base_path)
|
||||
self._base_storage = FileStorage(base_path)
|
||||
@@ -84,9 +87,10 @@ class ConcurrentStorage:
|
||||
self._max_batch_size = max_batch_size
|
||||
self._batch_task: asyncio.Task | None = None
|
||||
|
||||
# Locking
|
||||
self._file_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._global_lock = asyncio.Lock()
|
||||
# Locking - Use WeakValueDictionary to allow unused locks to be GC'd
|
||||
self._file_locks: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lru_tracking: OrderedDict = OrderedDict()
|
||||
self._max_locks = max_locks
|
||||
|
||||
# State
|
||||
self._running = False
|
||||
@@ -107,7 +111,10 @@ class ConcurrentStorage:
|
||||
|
||||
self._running = False
|
||||
|
||||
# Cancel batch task first to prevent queue competition
|
||||
# Flush remaining items
|
||||
await self._flush_pending()
|
||||
|
||||
# Cancel batch task
|
||||
if self._batch_task:
|
||||
self._batch_task.cancel()
|
||||
try:
|
||||
@@ -116,11 +123,40 @@ class ConcurrentStorage:
|
||||
pass
|
||||
self._batch_task = None
|
||||
|
||||
# Now flush remaining items (batch task is stopped)
|
||||
await self._flush_pending()
|
||||
|
||||
logger.info("ConcurrentStorage stopped")
|
||||
|
||||
async def _get_lock(self, lock_key: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a given key with safe eviction."""
|
||||
# 1. Check if lock exists
|
||||
lock = self._file_locks.get(lock_key)
|
||||
|
||||
if lock is not None:
|
||||
# OPTIMIZATION: Only update LRU for "run" locks.
|
||||
# This prevents high-frequency "index" locks from flushing out
|
||||
# the actual run locks we want to keep cached.
|
||||
if lock_key.startswith("run:"):
|
||||
if lock_key in self._lru_tracking:
|
||||
self._lru_tracking.move_to_end(lock_key)
|
||||
return lock
|
||||
|
||||
# 2. Create new lock
|
||||
lock = asyncio.Lock()
|
||||
self._file_locks[lock_key] = lock
|
||||
|
||||
# CRITICAL: Only add "run:" locks to the strong-ref LRU tracking.
|
||||
# Index locks live exclusively in WeakValueDictionary and are GC'd immediately.
|
||||
if lock_key.startswith("run:"):
|
||||
# Manage capacity only for run locks
|
||||
if len(self._lru_tracking) >= self._max_locks:
|
||||
# Remove oldest tracked lock (strong ref)
|
||||
# WeakValueDictionary will auto-remove the lock once no longer in use
|
||||
self._lru_tracking.popitem(last=False)
|
||||
|
||||
# Add strong reference to keep run lock alive
|
||||
self._lru_tracking[lock_key] = lock
|
||||
|
||||
return lock
|
||||
|
||||
# === RUN OPERATIONS (Async, Thread-Safe) ===
|
||||
|
||||
async def save_run(self, run: Run, immediate: bool = False) -> None:
|
||||
@@ -140,12 +176,40 @@ class ConcurrentStorage:
|
||||
self._cache[f"run:{run.id}"] = CacheEntry(run, time.time())
|
||||
|
||||
async def _save_run_locked(self, run: Run) -> None:
|
||||
"""Save a run with file locking."""
|
||||
"""Save a run with file locking, including index locks."""
|
||||
lock_key = f"run:{run.id}"
|
||||
async with self._file_locks[lock_key]:
|
||||
# Run in executor to avoid blocking event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._base_storage.save_run, run)
|
||||
|
||||
# Helper to get lock
|
||||
async def get_lock(k):
|
||||
return await self._get_lock(k)
|
||||
|
||||
# Acquire main lock
|
||||
run_lock = await get_lock(lock_key)
|
||||
|
||||
async with run_lock:
|
||||
# 2. Acquire index locks
|
||||
index_lock_keys = [
|
||||
f"index:by_goal:{run.goal_id}",
|
||||
f"index:by_status:{run.status.value}",
|
||||
]
|
||||
for node_id in run.metrics.nodes_executed:
|
||||
index_lock_keys.append(f"index:by_node:{node_id}")
|
||||
|
||||
# Collect index locks
|
||||
index_locks = [await get_lock(k) for k in index_lock_keys]
|
||||
|
||||
# Recursive acquisition
|
||||
async def with_locks(locks, callback):
|
||||
if not locks:
|
||||
return await callback()
|
||||
async with locks[0]:
|
||||
return await with_locks(locks[1:], callback)
|
||||
|
||||
async def perform_save():
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._base_storage.save_run, run)
|
||||
|
||||
await with_locks(index_locks, perform_save)
|
||||
|
||||
async def load_run(self, run_id: str, use_cache: bool = True) -> Run | None:
|
||||
"""
|
||||
@@ -158,23 +222,25 @@ class ConcurrentStorage:
|
||||
Returns:
|
||||
Run object or None if not found
|
||||
"""
|
||||
cache_key = f"run:{run_id}"
|
||||
if use_cache:
|
||||
cache_key = f"run:{run_id}"
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached and not cached.is_expired(self._cache_ttl):
|
||||
# CRITICAL: Touch LRU even on cache hit
|
||||
lock_key = f"run:{run_id}"
|
||||
if lock_key in self._lru_tracking:
|
||||
self._lru_tracking.move_to_end(lock_key)
|
||||
return cached.value
|
||||
|
||||
# Check cache
|
||||
if use_cache and cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if not entry.is_expired(self._cache_ttl):
|
||||
return entry.value
|
||||
|
||||
# Load from storage
|
||||
# CRITICAL: Acquire lock to trigger LRU update
|
||||
lock_key = f"run:{run_id}"
|
||||
async with self._file_locks[lock_key]:
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
run = await loop.run_in_executor(None, self._base_storage.load_run, run_id)
|
||||
|
||||
# Update cache
|
||||
if run:
|
||||
self._cache[cache_key] = CacheEntry(run, time.time())
|
||||
self._cache[f"run:{run_id}"] = CacheEntry(run, time.time())
|
||||
|
||||
return run
|
||||
|
||||
@@ -189,8 +255,10 @@ class ConcurrentStorage:
|
||||
return entry.value
|
||||
|
||||
# Load from storage
|
||||
loop = asyncio.get_event_loop()
|
||||
summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
|
||||
lock_key = f"summary:{run_id}"
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
summary = await loop.run_in_executor(None, self._base_storage.load_summary, run_id)
|
||||
|
||||
# Update cache
|
||||
if summary:
|
||||
@@ -201,7 +269,7 @@ class ConcurrentStorage:
|
||||
async def delete_run(self, run_id: str) -> bool:
|
||||
"""Delete a run from storage."""
|
||||
lock_key = f"run:{run_id}"
|
||||
async with self._file_locks[lock_key]:
|
||||
async with await self._get_lock(lock_key):
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, self._base_storage.delete_run, run_id)
|
||||
|
||||
@@ -215,7 +283,7 @@ class ConcurrentStorage:
|
||||
|
||||
async def get_runs_by_goal(self, goal_id: str) -> list[str]:
|
||||
"""Get all run IDs for a goal."""
|
||||
async with self._file_locks[f"index:by_goal:{goal_id}"]:
|
||||
async with await self._get_lock(f"index:by_goal:{goal_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_goal, goal_id)
|
||||
|
||||
@@ -223,13 +291,13 @@ class ConcurrentStorage:
|
||||
"""Get all run IDs with a status."""
|
||||
if isinstance(status, RunStatus):
|
||||
status = status.value
|
||||
async with self._file_locks[f"index:by_status:{status}"]:
|
||||
async with await self._get_lock(f"index:by_status:{status}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_status, status)
|
||||
|
||||
async def get_runs_by_node(self, node_id: str) -> list[str]:
|
||||
"""Get all run IDs that executed a node."""
|
||||
async with self._file_locks[f"index:by_node:{node_id}"]:
|
||||
async with await self._get_lock(f"index:by_node:{node_id}"):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._base_storage.get_runs_by_node, node_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user