Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d1bcae69b9 | |||
| 597fb0e5f9 | |||
| c38b3a9280 | |||
| cbbe39d28c | |||
| 82374eb18c | |||
| a36186cf54 | |||
| 9f28115889 | |||
| 7ce9333200 | |||
| 9af2f3e73c | |||
| dfa9fc47b3 | |||
| 3877aabcfd | |||
| e8f087cb37 | |||
| 3540e157f1 | |||
| 8f7eb28c0d | |||
| 500cdfc8e4 | |||
| 3580897c56 | |||
| 229c8095be | |||
| ce24424449 | |||
| 4810898cfa | |||
| 10cc651578 | |||
| 20f64bbf4f | |||
| e1cb78fecf | |||
| 6476eabdf5 | |||
| 95d5c156a1 | |||
| 18393b55d1 | |||
| 77491f2801 | |||
| 8d3cb6da72 | |||
| d1cf3f09b2 | |||
| 0d5b3a0ece | |||
| 4184d5ed2c | |||
| 60a5ad7279 | |||
| b2ec1f99b9 | |||
| 8da1903168 | |||
| 03952eca53 | |||
| 9197000690 | |||
| 36fb1c7804 |
+20
-8
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
|
||||
|
||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||
@@ -216,6 +216,9 @@ FastAPI application on port 8001 with health check at `GET /health`.
|
||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
||||
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||
|
||||
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||
|
||||
@@ -229,7 +232,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
|
||||
**Virtual Path System**:
|
||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||
|
||||
@@ -269,7 +272,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
||||
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
||||
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- `image_search/` - Image search via DuckDuckGo
|
||||
|
||||
### MCP System (`packages/harness/deerflow/mcp/`)
|
||||
@@ -338,18 +341,27 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
|
||||
|
||||
**Components**:
|
||||
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
|
||||
- `prompt.py` - Prompt templates for memory updates
|
||||
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
|
||||
|
||||
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
|
||||
**Per-User Isolation**:
|
||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||
- Absolute `storage_path` in config opts out of per-user isolation
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||
|
||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
||||
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
||||
|
||||
**Workflow**:
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
|
||||
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
||||
3. Background thread invokes LLM to extract context updates and facts
|
||||
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
|
||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||
|
||||
@@ -357,7 +369,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
|
||||
**Configuration** (`config.yaml` → `memory`):
|
||||
- `enabled` / `injection_enabled` - Master switches
|
||||
- `storage_path` - Path to memory.json
|
||||
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
|
||||
- `debounce_seconds` - Wait time before processing (default: 30)
|
||||
- `model_name` - LLM for updates (null = default model)
|
||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.channels.base import Channel
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -344,8 +345,9 @@ class FeishuChannel(Channel):
|
||||
return f"Failed to obtain the [{type}]"
|
||||
|
||||
paths = get_paths()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
|
||||
|
||||
ext = "png" if type == "image" else "bin"
|
||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||
|
||||
@@ -17,6 +17,7 @@ from langgraph_sdk.errors import ConflictError
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.store import ChannelStore
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -341,14 +342,15 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
||||
|
||||
attachments: list[ResolvedAttachment] = []
|
||||
paths = get_paths()
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
||||
user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
|
||||
for virtual_path in artifacts:
|
||||
# Security: only allow files from the agent outputs directory
|
||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||
continue
|
||||
try:
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
|
||||
# Verify the resolved path is actually under the outputs directory
|
||||
# (guards against path-traversal even after prefix check)
|
||||
try:
|
||||
|
||||
@@ -42,6 +42,11 @@ logger = logging.getLogger(__name__)
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
||||
|
||||
After admin creation, migrate orphan threads from the LangGraph
|
||||
store (metadata.user_id unset) to the admin account. This is the
|
||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||
authentication have existing LangGraph thread data that needs an
|
||||
owner assigned.
|
||||
First boot (no admin exists):
|
||||
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
||||
- Logs the token to stdout so the operator can copy-paste it into the
|
||||
@@ -52,7 +57,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||
existing LangGraph thread metadata that has no owner_id.
|
||||
|
||||
No SQL persistence migration is needed: the four owner_id columns
|
||||
No SQL persistence migration is needed: the four user_id columns
|
||||
(threads_meta, runs, run_events, feedback) only come into existence
|
||||
alongside the auth module via create_all, so freshly created tables
|
||||
never contain NULL-owner rows.
|
||||
@@ -96,6 +101,8 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
admin_id = str(row.id)
|
||||
|
||||
# LangGraph store orphan migration — non-fatal.
|
||||
# This covers the "no-auth → with-auth" upgrade path for users
|
||||
# whose existing LangGraph thread metadata has no user_id set.
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
try:
|
||||
@@ -127,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
||||
"""Migrate LangGraph store threads with no user_id to the given admin.
|
||||
|
||||
Uses cursor pagination so all orphans are migrated regardless of
|
||||
count. Returns the number of rows migrated.
|
||||
@@ -135,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
migrated = 0
|
||||
async for item in _iter_store_items(store, ("threads",)):
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("owner_id"):
|
||||
metadata["owner_id"] = admin_user_id
|
||||
if not metadata.get("user_id"):
|
||||
metadata["user_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
|
||||
@@ -233,18 +233,18 @@ def require_permission(
|
||||
# (``threads_meta`` table). We verify ownership via
|
||||
# ``ThreadMetaStore.check_access``: it returns True for
|
||||
# missing rows (untracked legacy thread) and for rows whose
|
||||
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
||||
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* owner_id triggers 404.
|
||||
# row with a *different* user_id triggers 404.
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
allowed = await thread_meta_repo.check_access(
|
||||
thread_store = get_thread_store(request)
|
||||
allowed = await thread_store.check_access(
|
||||
thread_id,
|
||||
str(auth.user.id),
|
||||
require_existing=require_existing,
|
||||
|
||||
+16
-10
@@ -1,8 +1,7 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||
``None``.
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
@@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
@@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
if sf is not None:
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
app.state.run_store = RunRepository(sf)
|
||||
app.state.feedback_repo = FeedbackRepository(sf)
|
||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||
else:
|
||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
app.state.run_store = MemoryRunStore()
|
||||
app.state.feedback_repo = None
|
||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
||||
|
||||
from deerflow.persistence.thread_meta import make_thread_store
|
||||
|
||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
@@ -80,7 +80,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters -- called by routers per-request
|
||||
# Getters – called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -110,7 +110,12 @@ def get_store(request: Request):
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
||||
def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
"""Return the thread metadata store (SQL or memory-backed)."""
|
||||
val = getattr(request.app.state, "thread_store", None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||
return val
|
||||
|
||||
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
@@ -128,10 +133,11 @@ def get_run_context(request: Request) -> RunContext:
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||
thread_meta_repo=get_thread_meta_repo(request),
|
||||
thread_store=get_thread_store(request),
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -93,14 +93,14 @@ async def authenticate(request):
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||
Gateway stores thread ownership as ``metadata.user_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp owner_id into metadata
|
||||
# On create/update: stamp user_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["owner_id"] = ctx.user.identity
|
||||
metadata["user_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"owner_id": ctx.user.identity}
|
||||
return {"user_id": ctx.user.identity}
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
@@ -22,7 +23,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
try:
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
||||
except ValueError as e:
|
||||
status = 403 if "traversal" in str(e) else 400
|
||||
raise HTTPException(status_code=status, detail=str(e))
|
||||
|
||||
@@ -30,11 +30,16 @@ class FeedbackCreateRequest(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||
|
||||
|
||||
class FeedbackUpsertRequest(BaseModel):
|
||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
owner_id: str | None = None
|
||||
user_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
@@ -53,6 +58,57 @@ class FeedbackStatsResponse(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def upsert_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackUpsertRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Create or update feedback for a run (idempotent)."""
|
||||
if body.rating not in (1, -1):
|
||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||
|
||||
user_id = await get_current_user(request)
|
||||
|
||||
run_store = get_run_store(request)
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
if run.get("thread_id") != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.upsert(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
||||
async def delete_run_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete the current user's feedback for a run."""
|
||||
user_id = await get_current_user(request)
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
deleted = await feedback_repo.delete_by_run(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def create_feedback(
|
||||
@@ -80,7 +136,7 @@ async def create_feedback(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
owner_id=user_id,
|
||||
user_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from deerflow.agents.memory.updater import (
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["memory"])
|
||||
|
||||
@@ -147,7 +148,7 @@ async def get_memory() -> MemoryResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
memory_data = get_memory_data()
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -167,7 +168,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
Returns:
|
||||
The reloaded memory data.
|
||||
"""
|
||||
memory_data = reload_memory_data()
|
||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -181,7 +182,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
async def clear_memory() -> MemoryResponse:
|
||||
"""Clear all persisted memory data."""
|
||||
try:
|
||||
memory_data = clear_memory_data()
|
||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||
|
||||
@@ -202,6 +203,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
@@ -221,7 +223,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
"""Delete a single fact from memory by fact id."""
|
||||
try:
|
||||
memory_data = delete_memory_fact(fact_id)
|
||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
@@ -245,6 +247,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
@@ -265,7 +268,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
)
|
||||
async def export_memory() -> MemoryResponse:
|
||||
"""Export the current memory data."""
|
||||
memory_data = get_memory_data()
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -279,7 +282,7 @@ async def export_memory() -> MemoryResponse:
|
||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
"""Import and persist memory data."""
|
||||
try:
|
||||
memory_data = import_memory_data(request.model_dump())
|
||||
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
|
||||
@@ -337,7 +340,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
||||
Combined memory configuration and current data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
memory_data = get_memory_data()
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
return MemoryStatusResponse(
|
||||
config=MemoryConfigResponse(
|
||||
|
||||
@@ -11,10 +11,11 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
@@ -85,3 +86,57 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run-scoped read endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _resolve_run(run_id: str, request: Request) -> dict:
|
||||
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
||||
run_store = get_run_store(request)
|
||||
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return record
|
||||
|
||||
|
||||
@router.get("/{run_id}/messages")
|
||||
@require_permission("runs", "read")
|
||||
async def run_messages(
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(default=50, le=200, ge=1),
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> dict:
|
||||
"""Return paginated messages for a run (cursor-based).
|
||||
|
||||
Pagination:
|
||||
- after_seq: messages with seq > after_seq (forward)
|
||||
- before_seq: messages with seq < before_seq (backward)
|
||||
- neither: latest messages
|
||||
|
||||
Response: { data: [...], has_more: bool }
|
||||
"""
|
||||
run = await _resolve_run(run_id, request)
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
run["thread_id"], run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
has_more = len(rows) > limit
|
||||
data = rows[:limit] if has_more else rows
|
||||
return {"data": data, "has_more": has_more}
|
||||
|
||||
|
||||
@router.get("/{run_id}/feedback")
|
||||
@require_permission("runs", "read")
|
||||
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
||||
"""Return all feedback for a run."""
|
||||
run = await _resolve_run(run_id, request)
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
||||
|
||||
@@ -20,7 +20,7 @@ from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
@@ -291,17 +291,62 @@ async def list_thread_messages(
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages for a thread (across all runs)."""
|
||||
"""Return displayable messages for a thread (across all runs), with feedback attached."""
|
||||
event_store = get_run_event_store(request)
|
||||
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
|
||||
# Attach feedback to the last AI message of each run
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
user_id = await get_current_user(request)
|
||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||
|
||||
# Find the last ai_message per run_id
|
||||
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.get("event_type") == "ai_message":
|
||||
last_ai_per_run[msg["run_id"]] = i
|
||||
|
||||
# Attach feedback field
|
||||
last_ai_indices = set(last_ai_per_run.values())
|
||||
for i, msg in enumerate(messages):
|
||||
if i in last_ai_indices:
|
||||
run_id = msg["run_id"]
|
||||
fb = feedback_map.get(run_id)
|
||||
msg["feedback"] = {
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
} if fb else None
|
||||
else:
|
||||
msg["feedback"] = None
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
|
||||
"""Return displayable messages for a specific run."""
|
||||
async def list_run_messages(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(default=50, le=200, ge=1),
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> dict:
|
||||
"""Return paginated messages for a specific run.
|
||||
|
||||
Response: { data: [...], has_more: bool }
|
||||
"""
|
||||
event_store = get_run_event_store(request)
|
||||
return await event_store.list_messages_by_run(thread_id, run_id)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id, run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
has_more = len(rows) > limit
|
||||
data = rows[:limit] if has_more else rows
|
||||
return {"data": data, "has_more": has_more}
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
||||
|
||||
@@ -13,6 +13,7 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -21,10 +22,11 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
@@ -34,7 +36,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||
# inbound model below so a malicious client cannot reflect a forged
|
||||
# owner identity through the API surface. Defense-in-depth — the
|
||||
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
||||
|
||||
@@ -142,11 +144,11 @@ class ThreadHistoryRequest(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
path_manager.delete_thread_dir(thread_id, user_id=user_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
@@ -194,10 +196,10 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
and removes the thread_meta row from the configured ThreadMetaStore
|
||||
(sqlite or memory).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
@@ -211,8 +213,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||
# so the deleted thread no longer appears in /threads/search.
|
||||
try:
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
await thread_meta_repo.delete(thread_id)
|
||||
thread_store = get_thread_store(request)
|
||||
await thread_store.delete(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -227,17 +229,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
and an empty checkpoint (so state endpoints work immediately).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_meta_repo.get(thread_id)
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@@ -249,7 +251,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
|
||||
# Write thread_meta so the thread appears in /threads/search immediately
|
||||
try:
|
||||
await thread_meta_repo.create(
|
||||
await thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
metadata=body.metadata,
|
||||
@@ -293,9 +295,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
Delegates to the configured ThreadMetaStore implementation
|
||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
repo = get_thread_meta_repo(request)
|
||||
repo = get_thread_store(request)
|
||||
rows = await repo.search(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
@@ -320,22 +322,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
record = await thread_meta_repo.get(thread_id)
|
||||
thread_store = get_thread_store(request)
|
||||
record = await thread_store.get(thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||
try:
|
||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
||||
await thread_store.update_metadata(thread_id, body.metadata)
|
||||
except Exception:
|
||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||
|
||||
# Re-read to get the merged metadata + refreshed updated_at
|
||||
record = await thread_meta_repo.get(thread_id) or record
|
||||
record = await thread_store.get(thread_id) or record
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
@@ -354,12 +356,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
execution status from the checkpointer. Falls back to the checkpointer
|
||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
||||
record: dict | None = await thread_store.get(thread_id)
|
||||
|
||||
# Derive accurate status from the checkpointer
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
@@ -402,6 +404,165 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event-store-backed message loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LEGACY_CMD_INNER_CONTENT_RE = re.compile(
|
||||
r"ToolMessage\(content=(?P<q>['\"])(?P<inner>.*?)(?P=q)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_legacy_command_repr(content_field: Any) -> Any:
|
||||
"""Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr.
|
||||
|
||||
Runs captured before the ``on_tool_end`` fix in ``journal.py`` stored
|
||||
``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the
|
||||
tool_result content. New runs store ``'X'`` directly. For legacy rows, try
|
||||
to extract ``'X'`` defensively; return the original string if extraction
|
||||
fails (still no worse than the checkpoint fallback for summarized threads).
|
||||
"""
|
||||
if not isinstance(content_field, str) or not content_field.startswith("Command(update="):
|
||||
return content_field
|
||||
match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field)
|
||||
return match.group("inner") if match else content_field
|
||||
|
||||
|
||||
async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None:
|
||||
"""Load the full message stream for ``thread_id`` from the event store.
|
||||
|
||||
The event store is append-only and unaffected by summarization — the
|
||||
checkpoint's ``channel_values["messages"]`` is rewritten in-place when the
|
||||
SummarizationMiddleware runs, which drops all pre-summarize messages. The
|
||||
event store retains the full transcript, so callers in Gateway mode should
|
||||
prefer it for rendering the conversation history.
|
||||
|
||||
In addition to the core message content, this helper attaches two extra
|
||||
fields to every returned dict:
|
||||
|
||||
- ``run_id``: the ``run_id`` of the event that produced this message.
|
||||
Always present.
|
||||
- ``feedback``: thumbs-up/down data. Present only on the **final
|
||||
``ai_message`` of each run** (matching the per-run feedback semantics
|
||||
of ``POST /api/threads/{id}/runs/{run_id}/feedback``). The frontend uses
|
||||
the presence of this field to decide whether to render the feedback
|
||||
button, which sidesteps the positional-index mapping bug that an
|
||||
out-of-band ``/messages`` fetch exhibited.
|
||||
|
||||
Behaviour contract:
|
||||
|
||||
- **Full pagination.** ``RunEventStore.list_messages`` returns the newest
|
||||
``limit`` records when no cursor is given, so a fixed limit silently
|
||||
drops older messages on long threads. We size the read from
|
||||
``count_messages()`` and then page forward with ``after_seq`` cursors.
|
||||
- **Copy-on-read.** Each content dict is copied before ``id`` is patched
|
||||
so the live store object is never mutated; ``MemoryRunEventStore``
|
||||
returns live references.
|
||||
- **Stable ids.** Messages with ``id=None`` (human + tool_result) receive
|
||||
a deterministic ``uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")`` so React
|
||||
keys are stable across requests without altering stored data. AI messages
|
||||
retain their LLM-assigned ``lc_run--*`` ids.
|
||||
- **Legacy Command repr.** Rows captured before the ``journal.py``
|
||||
``on_tool_end`` fix stored ``str(Command(update={...}))`` as the tool
|
||||
result content. ``_sanitize_legacy_command_repr`` extracts the inner
|
||||
ToolMessage text.
|
||||
- **User context.** ``DbRunEventStore`` is user-scoped by default via
|
||||
``resolve_user_id(AUTO)`` in ``runtime/user_context.py``. This helper
|
||||
must run inside a request where ``@require_permission`` has populated
|
||||
the user contextvar. Both callers below are decorated appropriately.
|
||||
Do not call this helper from CLI or migration scripts without passing
|
||||
``user_id=None`` explicitly to the underlying store methods.
|
||||
|
||||
Returns ``None`` when the event store is not configured or has no message
|
||||
events for this thread, so callers fall back to checkpoint messages.
|
||||
"""
|
||||
try:
|
||||
event_store = get_run_event_store(request)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
total = await event_store.count_messages(thread_id)
|
||||
except Exception:
|
||||
logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||
return None
|
||||
if not total:
|
||||
return None
|
||||
|
||||
# Batch by page_size to keep memory bounded for very long threads.
|
||||
page_size = 500
|
||||
collected: list[dict] = []
|
||||
after_seq: int | None = None
|
||||
while True:
|
||||
try:
|
||||
page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
|
||||
except Exception:
|
||||
logger.exception("list_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||
return None
|
||||
if not page:
|
||||
break
|
||||
collected.extend(page)
|
||||
if len(page) < page_size:
|
||||
break
|
||||
next_cursor = page[-1].get("seq")
|
||||
if next_cursor is None or (after_seq is not None and next_cursor <= after_seq):
|
||||
break
|
||||
after_seq = next_cursor
|
||||
|
||||
# Build the message list; track the final ``ai_message`` index per run so
|
||||
# feedback can be attached at the right position (matches thread_runs.py).
|
||||
messages: list[dict] = []
|
||||
last_ai_per_run: dict[str, int] = {}
|
||||
for evt in collected:
|
||||
raw = evt.get("content")
|
||||
if not isinstance(raw, dict) or "type" not in raw:
|
||||
continue
|
||||
content = dict(raw)
|
||||
if content.get("id") is None:
|
||||
content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}"))
|
||||
if content.get("type") == "tool":
|
||||
content["content"] = _sanitize_legacy_command_repr(content.get("content"))
|
||||
run_id = evt.get("run_id")
|
||||
if run_id:
|
||||
content["run_id"] = run_id
|
||||
if evt.get("event_type") == "ai_message" and run_id:
|
||||
last_ai_per_run[run_id] = len(messages)
|
||||
messages.append(content)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Attach feedback to the final ai_message of each run. If the feedback
|
||||
# subsystem is unavailable, leave the ``feedback`` field absent entirely
|
||||
# so the frontend hides the button rather than showing it over a broken
|
||||
# write path.
|
||||
feedback_available = False
|
||||
feedback_map: dict[str, dict] = {}
|
||||
try:
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
user_id = await get_current_user(request)
|
||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||
feedback_available = True
|
||||
except Exception:
|
||||
logger.exception("feedback lookup failed for thread %s", sanitize_log_param(thread_id))
|
||||
|
||||
if feedback_available:
|
||||
for run_id, idx in last_ai_per_run.items():
|
||||
fb = feedback_map.get(run_id)
|
||||
messages[idx]["feedback"] = (
|
||||
{
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
}
|
||||
if fb
|
||||
else None
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
@@ -440,8 +601,15 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
values = serialize_channel_values(channel_values)
|
||||
|
||||
# Prefer event-store messages: append-only, immune to summarization.
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
values=values,
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
@@ -462,10 +630,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||
change immediately in both sqlite and memory backends.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
from app.gateway.deps import get_thread_store
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_store = get_thread_store(request)
|
||||
|
||||
# checkpoint_ns must be present in the config for aput — default to ""
|
||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||
@@ -529,7 +697,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
new_title = body.values["title"]
|
||||
if new_title: # Skip empty strings and None
|
||||
try:
|
||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
||||
await thread_store.update_display_name(thread_id, new_title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -559,6 +727,11 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
# Load the full event-store message stream once; attach to the latest
|
||||
# checkpoint entry only (matching the prior semantics). The event store
|
||||
# is append-only and immune to summarization.
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_latest_checkpoint = True
|
||||
try:
|
||||
@@ -582,11 +755,17 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
if thread_data := channel_values.get("thread_data"):
|
||||
values["thread_data"] = thread_data
|
||||
|
||||
# Attach messages from checkpointer only for the latest checkpoint
|
||||
# Attach messages only to the latest checkpoint. Prefer the
|
||||
# event-store stream (complete and unaffected by summarization);
|
||||
# fall back to checkpoint channel_values when the event store is
|
||||
# unavailable or empty.
|
||||
if is_latest_checkpoint:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
else:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
|
||||
# Derive next tasks
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
@@ -69,7 +70,7 @@ async def upload_files(
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
uploaded_files = []
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
@@ -147,7 +148,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||
enrich_file_listing(result, thread_id)
|
||||
|
||||
# Gateway additionally includes the sandbox-relative path.
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
for f in result["files"]:
|
||||
f["path"] = str(sandbox_uploads / f["filename"])
|
||||
|
||||
|
||||
@@ -229,15 +229,15 @@ async def start_run(
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_meta_repo.create(
|
||||
await run_ctx.thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -285,7 +285,7 @@ async def start_run(
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
|
||||
@@ -124,7 +124,7 @@ title:
|
||||
# checkpointer.py
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
|
||||
checkpointer = SqliteSaver.from_conn_string("deerflow.db")
|
||||
```
|
||||
|
||||
```json
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .factory import create_deerflow_agent
|
||||
from .features import Next, Prev, RuntimeFeatures
|
||||
from .lead_agent import make_lead_agent
|
||||
@@ -18,7 +17,4 @@ __all__ = [
|
||||
"make_lead_agent",
|
||||
"SandboxState",
|
||||
"ThreadState",
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"make_checkpointer",
|
||||
]
|
||||
|
||||
@@ -519,12 +519,13 @@ def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
try:
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
config = get_memory_config()
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
|
||||
if not memory_content.strip():
|
||||
|
||||
@@ -20,6 +20,7 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
agent_name: str | None = None
|
||||
user_id: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
@@ -53,6 +55,9 @@ class MemoryUpdateQueue:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
|
||||
survives the threading.Timer boundary (ContextVar does not propagate across
|
||||
raw threads).
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
@@ -71,6 +76,7 @@ class MemoryUpdateQueue:
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
@@ -136,6 +142,7 @@ class MemoryUpdateQueue:
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
user_id=context.user_id,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
|
||||
"""Abstract base class for memory storage providers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Force reload memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Save memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the file memory storage."""
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
||||
|
||||
def _validate_agent_name(self, agent_name: str) -> None:
|
||||
"""Validate that the agent name is safe to use in filesystem paths.
|
||||
@@ -78,21 +78,29 @@ class FileMemoryStorage(MemoryStorage):
|
||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||
|
||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
if user_id is not None:
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().user_agent_memory_file(user_id, agent_name)
|
||||
config = get_memory_config()
|
||||
if config.storage_path and Path(config.storage_path).is_absolute():
|
||||
return Path(config.storage_path)
|
||||
return get_paths().user_memory_file(user_id)
|
||||
# Legacy: no user_id
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
return p if p.is_absolute() else get_paths().base_dir / p
|
||||
return get_paths().memory_file
|
||||
|
||||
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
if not file_path.exists():
|
||||
return create_empty_memory()
|
||||
@@ -105,40 +113,42 @@ class FileMemoryStorage(MemoryStorage):
|
||||
logger.warning("Failed to load memory file: %s", e)
|
||||
return create_empty_memory()
|
||||
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data (cached with file modification time check)."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
cache_key = (user_id, agent_name)
|
||||
cached = self._memory_cache.get(cache_key)
|
||||
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return cached[0]
|
||||
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
cache_key = (user_id, agent_name)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -155,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
cache_key = (user_id, agent_name)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
logger.info("Memory saved to %s", file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
|
||||
@@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
|
||||
return create_empty_memory()
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||
return get_memory_storage().save(memory_data, agent_name)
|
||||
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data via storage provider."""
|
||||
return get_memory_storage().load(agent_name)
|
||||
return get_memory_storage().load(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data via storage provider."""
|
||||
return get_memory_storage().reload(agent_name)
|
||||
return get_memory_storage().reload(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider.
|
||||
|
||||
Args:
|
||||
memory_data: Full memory payload to persist.
|
||||
agent_name: If provided, imports into per-agent memory.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
The saved memory data after storage normalization.
|
||||
@@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
|
||||
OSError: If persisting the imported memory fails.
|
||||
"""
|
||||
storage = get_memory_storage()
|
||||
if not storage.save(memory_data, agent_name):
|
||||
if not storage.save(memory_data, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save imported memory data")
|
||||
return storage.load(agent_name)
|
||||
return storage.load(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Clear all stored memory data and persist an empty structure."""
|
||||
cleared_memory = create_empty_memory()
|
||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
||||
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save cleared memory data")
|
||||
return cleared_memory
|
||||
|
||||
@@ -81,6 +82,8 @@ def create_memory_fact(
|
||||
category: str = "context",
|
||||
confidence: float = 0.5,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new fact and persist the updated memory data."""
|
||||
normalized_content = content.strip()
|
||||
@@ -90,7 +93,7 @@ def create_memory_fact(
|
||||
normalized_category = category.strip() or "context"
|
||||
validated_confidence = _validate_confidence(confidence)
|
||||
now = utc_now_iso_z()
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
updated_memory = dict(memory_data)
|
||||
facts = list(memory_data.get("facts", []))
|
||||
facts.append(
|
||||
@@ -105,15 +108,15 @@ def create_memory_fact(
|
||||
)
|
||||
updated_memory["facts"] = facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save memory data after creating fact")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Delete a fact by its id and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
facts = memory_data.get("facts", [])
|
||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||
if len(updated_facts) == len(facts):
|
||||
@@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
|
||||
updated_memory = dict(memory_data)
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
@@ -134,9 +137,11 @@ def update_memory_fact(
|
||||
category: str | None = None,
|
||||
confidence: float | None = None,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing fact and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
updated_memory = dict(memory_data)
|
||||
updated_facts: list[dict[str, Any]] = []
|
||||
found = False
|
||||
@@ -163,7 +168,7 @@ def update_memory_fact(
|
||||
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
@@ -276,6 +281,7 @@ class MemoryUpdater:
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
@@ -285,6 +291,7 @@ class MemoryUpdater:
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -298,7 +305,7 @@ class MemoryUpdater:
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data(agent_name)
|
||||
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
@@ -353,7 +360,7 @@ class MemoryUpdater:
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
|
||||
# Save
|
||||
return get_memory_storage().save(updated_memory, agent_name)
|
||||
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||
@@ -455,6 +462,7 @@ def update_memory_from_conversation(
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
@@ -464,9 +472,10 @@ def update_memory_from_conversation(
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
||||
|
||||
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -236,11 +237,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
# Capture user_id at enqueue time while the request context is still alive.
|
||||
# threading.Timer fires on a different thread where ContextVar values are not
|
||||
# propagated, so we must store user_id explicitly in ConversationContext.
|
||||
user_id = get_effective_user_id()
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,32 +47,34 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
return {
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
self._paths.ensure_thread_dirs(thread_id)
|
||||
return self._get_thread_paths(thread_id)
|
||||
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
return self._get_thread_paths(thread_id, user_id=user_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
@@ -84,12 +87,14 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
paths = self._get_thread_paths(thread_id, user_id=user_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
return {
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -221,7 +222,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||
|
||||
@@ -40,6 +40,7 @@ from deerflow.config.app_config import get_app_config, reload_app_config
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.skills.installer import install_skill_from_archive
|
||||
from deerflow.uploads.manager import (
|
||||
claim_unique_filename,
|
||||
@@ -240,7 +241,7 @@ class DeerFlowClient:
|
||||
}
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer import get_checkpointer
|
||||
from deerflow.runtime.checkpointer import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
if checkpointer is not None:
|
||||
@@ -374,7 +375,7 @@ class DeerFlowClient:
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
@@ -429,7 +430,7 @@ class DeerFlowClient:
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
@@ -769,19 +770,19 @@ class DeerFlowClient:
|
||||
"""
|
||||
from deerflow.agents.memory.updater import get_memory_data
|
||||
|
||||
return get_memory_data()
|
||||
return get_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
def export_memory(self) -> dict:
|
||||
"""Export current memory data for backup or transfer."""
|
||||
from deerflow.agents.memory.updater import get_memory_data
|
||||
|
||||
return get_memory_data()
|
||||
return get_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
def import_memory(self, memory_data: dict) -> dict:
|
||||
"""Import and persist full memory data."""
|
||||
from deerflow.agents.memory.updater import import_memory_data
|
||||
|
||||
return import_memory_data(memory_data)
|
||||
return import_memory_data(memory_data, user_id=get_effective_user_id())
|
||||
|
||||
def get_model(self, name: str) -> dict | None:
|
||||
"""Get a specific model's configuration by name.
|
||||
@@ -956,13 +957,13 @@ class DeerFlowClient:
|
||||
"""
|
||||
from deerflow.agents.memory.updater import reload_memory_data
|
||||
|
||||
return reload_memory_data()
|
||||
return reload_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
def clear_memory(self) -> dict:
|
||||
"""Clear all persisted memory data."""
|
||||
from deerflow.agents.memory.updater import clear_memory_data
|
||||
|
||||
return clear_memory_data()
|
||||
return clear_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
||||
"""Create a single fact manually."""
|
||||
@@ -1179,7 +1180,7 @@ class DeerFlowClient:
|
||||
ValueError: If the path is invalid.
|
||||
"""
|
||||
try:
|
||||
actual = get_paths().resolve_virtual_path(thread_id, path)
|
||||
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
|
||||
except ValueError as exc:
|
||||
if "traversal" in str(exc):
|
||||
from deerflow.uploads.manager import PathTraversalError
|
||||
|
||||
@@ -27,6 +27,7 @@ except ImportError: # pragma: no cover - Windows fallback
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
@@ -260,15 +261,16 @@ class AioSandboxProvider(SandboxProvider):
|
||||
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
||||
"""
|
||||
paths = get_paths()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
|
||||
return [
|
||||
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||
# the ACP subprocess writes from the host side, not from within the container).
|
||||
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
||||
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -480,8 +482,9 @@ class AioSandboxProvider(SandboxProvider):
|
||||
across multiple processes, preventing container-name conflicts.
|
||||
"""
|
||||
paths = get_paths()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
|
||||
|
||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||
locked = False
|
||||
|
||||
@@ -4,8 +4,12 @@ Controls BOTH the LangGraph checkpointer and the DeerFlow application
|
||||
persistence layer (runs, threads metadata, users, etc.). The user
|
||||
configures one backend; the system handles physical separation details.
|
||||
|
||||
SQLite mode: checkpointer and app use different .db files in the same
|
||||
directory to avoid write-lock contention. This is automatic.
|
||||
SQLite mode: checkpointer and app share a single .db file
|
||||
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
|
||||
connection. WAL allows concurrent readers and a single writer without
|
||||
blocking, making a unified file safe for both workloads. Writers
|
||||
that contend for the lock wait via the default 5-second sqlite3
|
||||
busy timeout rather than failing immediately.
|
||||
|
||||
Postgres mode: both use the same database URL but maintain independent
|
||||
connection pools with different lifecycles.
|
||||
@@ -40,7 +44,7 @@ class DatabaseConfig(BaseModel):
|
||||
)
|
||||
sqlite_dir: str = Field(
|
||||
default=".deer-flow/data",
|
||||
description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
|
||||
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
|
||||
)
|
||||
postgres_url: str = Field(
|
||||
default="",
|
||||
@@ -69,21 +73,27 @@ class DatabaseConfig(BaseModel):
|
||||
|
||||
return str(Path(self.sqlite_dir).resolve())
|
||||
|
||||
@property
|
||||
def sqlite_path(self) -> str:
|
||||
"""Unified SQLite file path shared by checkpointer and app."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||
|
||||
# Backward-compatible aliases
|
||||
@property
|
||||
def checkpointer_sqlite_path(self) -> str:
|
||||
"""SQLite file path for the LangGraph checkpointer."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")
|
||||
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
|
||||
@property
|
||||
def app_sqlite_path(self) -> str:
|
||||
"""SQLite file path for application ORM data."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "app.db")
|
||||
"""SQLite file path for application ORM data (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
|
||||
@property
|
||||
def app_sqlalchemy_url(self) -> str:
|
||||
"""SQLAlchemy async URL for the application ORM engine."""
|
||||
if self.backend == "sqlite":
|
||||
return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
|
||||
return f"sqlite+aiosqlite:///{self.sqlite_path}"
|
||||
if self.backend == "postgres":
|
||||
url = self.postgres_url
|
||||
if url.startswith("postgresql://"):
|
||||
|
||||
@@ -14,8 +14,9 @@ class MemoryConfig(BaseModel):
|
||||
default="",
|
||||
description=(
|
||||
"Path to store memory data. "
|
||||
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
|
||||
"Absolute paths are used as-is. "
|
||||
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
|
||||
"Absolute paths are used as-is and opt out of per-user isolation "
|
||||
"(all users share the same file). "
|
||||
"Relative paths are resolved against `Paths.base_dir` "
|
||||
"(not the backend working directory). "
|
||||
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path, PureWindowsPath
|
||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||
|
||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
@@ -22,6 +23,13 @@ def _validate_thread_id(thread_id: str) -> str:
|
||||
return thread_id
|
||||
|
||||
|
||||
def _validate_user_id(user_id: str) -> str:
|
||||
"""Validate a user ID before using it in filesystem paths."""
|
||||
if not _SAFE_USER_ID_RE.match(user_id):
|
||||
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
||||
return user_id
|
||||
|
||||
|
||||
def _join_host_path(base: str, *parts: str) -> str:
|
||||
"""Join host filesystem path segments while preserving native style.
|
||||
|
||||
@@ -134,44 +142,63 @@ class Paths:
|
||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
return self.agent_dir(name) / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str) -> Path:
|
||||
def user_dir(self, user_id: str) -> Path:
|
||||
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
||||
return self.base_dir / "users" / _validate_user_id(user_id)
|
||||
|
||||
def user_memory_file(self, user_id: str) -> Path:
|
||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||
return self.user_dir(user_id) / "memory.json"
|
||||
|
||||
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
||||
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
|
||||
Host path for a thread's data.
|
||||
|
||||
When *user_id* is provided:
|
||||
`{base_dir}/users/{user_id}/threads/{thread_id}/`
|
||||
Otherwise (legacy layout):
|
||||
`{base_dir}/threads/{thread_id}/`
|
||||
|
||||
This directory contains a `user-data/` subdirectory that is mounted
|
||||
as `/mnt/user-data/` inside the sandbox.
|
||||
|
||||
Raises:
|
||||
ValueError: If `thread_id` contains unsafe characters (path separators
|
||||
or `..`) that could cause directory traversal.
|
||||
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
|
||||
separators or `..`) that could cause directory traversal.
|
||||
"""
|
||||
if user_id is not None:
|
||||
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
|
||||
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||
|
||||
def sandbox_work_dir(self, thread_id: str) -> Path:
|
||||
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for the agent's workspace directory.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
||||
Sandbox: `/mnt/user-data/workspace/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "workspace"
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
|
||||
|
||||
def sandbox_uploads_dir(self, thread_id: str) -> Path:
|
||||
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for user-uploaded files.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
||||
Sandbox: `/mnt/user-data/uploads/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "uploads"
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
|
||||
|
||||
def sandbox_outputs_dir(self, thread_id: str) -> Path:
|
||||
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for agent-generated artifacts.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
||||
Sandbox: `/mnt/user-data/outputs/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "outputs"
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
|
||||
|
||||
def acp_workspace_dir(self, thread_id: str) -> Path:
|
||||
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for the ACP workspace of a specific thread.
|
||||
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
||||
@@ -180,41 +207,43 @@ class Paths:
|
||||
Each thread gets its own isolated ACP workspace so that concurrent
|
||||
sessions cannot read each other's ACP agent outputs.
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "acp-workspace"
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
|
||||
|
||||
def sandbox_user_data_dir(self, thread_id: str) -> Path:
|
||||
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
Host path for the user-data root.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
||||
Sandbox: `/mnt/user-data/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data"
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
|
||||
|
||||
def host_thread_dir(self, thread_id: str) -> str:
|
||||
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for a thread directory, preserving Windows path syntax."""
|
||||
if user_id is not None:
|
||||
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
|
||||
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
||||
|
||||
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
|
||||
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for a thread's user-data root."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
|
||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
|
||||
|
||||
def host_sandbox_work_dir(self, thread_id: str) -> str:
|
||||
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for the workspace mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
|
||||
|
||||
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
|
||||
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for the uploads mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
|
||||
|
||||
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
|
||||
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for the outputs mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
|
||||
|
||||
def host_acp_workspace_dir(self, thread_id: str) -> str:
|
||||
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
"""Host path for the ACP workspace mount source."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
|
||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
|
||||
|
||||
def ensure_thread_dirs(self, thread_id: str) -> None:
|
||||
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||
"""Create all standard sandbox directories for a thread.
|
||||
|
||||
Directories are created with mode 0o777 so that sandbox containers
|
||||
@@ -228,24 +257,24 @@ class Paths:
|
||||
ACP agent invocation.
|
||||
"""
|
||||
for d in [
|
||||
self.sandbox_work_dir(thread_id),
|
||||
self.sandbox_uploads_dir(thread_id),
|
||||
self.sandbox_outputs_dir(thread_id),
|
||||
self.acp_workspace_dir(thread_id),
|
||||
self.sandbox_work_dir(thread_id, user_id=user_id),
|
||||
self.sandbox_uploads_dir(thread_id, user_id=user_id),
|
||||
self.sandbox_outputs_dir(thread_id, user_id=user_id),
|
||||
self.acp_workspace_dir(thread_id, user_id=user_id),
|
||||
]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
d.chmod(0o777)
|
||||
|
||||
def delete_thread_dir(self, thread_id: str) -> None:
|
||||
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||
"""Delete all persisted data for a thread.
|
||||
|
||||
The operation is idempotent: missing thread directories are ignored.
|
||||
"""
|
||||
thread_dir = self.thread_dir(thread_id)
|
||||
thread_dir = self.thread_dir(thread_id, user_id=user_id)
|
||||
if thread_dir.exists():
|
||||
shutil.rmtree(thread_dir)
|
||||
|
||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
|
||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
||||
|
||||
Args:
|
||||
@@ -253,6 +282,7 @@ class Paths:
|
||||
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
||||
``/mnt/user-data/outputs/report.pdf``.
|
||||
Leading slashes are stripped before matching.
|
||||
user_id: Optional user ID for user-scoped path resolution.
|
||||
|
||||
Returns:
|
||||
The resolved absolute host filesystem path.
|
||||
@@ -270,7 +300,7 @@ class Paths:
|
||||
raise ValueError(f"Path must start with /{prefix}")
|
||||
|
||||
relative = stripped[len(prefix) :].lstrip("/")
|
||||
base = self.sandbox_user_data_dir(thread_id).resolve()
|
||||
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
|
||||
actual = (base / relative).resolve()
|
||||
|
||||
try:
|
||||
|
||||
@@ -98,6 +98,11 @@ async def init_engine(
|
||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||
# at WAL checkpoint boundaries instead of every commit.
|
||||
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
|
||||
# driver already defaults to a 5-second busy timeout (see the
|
||||
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
|
||||
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
|
||||
# it again would be a no-op.
|
||||
@event.listens_for(_engine.sync_engine, "connect")
|
||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||
cursor = dbapi_conn.cursor()
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@@ -13,10 +13,14 @@ from deerflow.persistence.base import Base
|
||||
class FeedbackRow(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||
)
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||
# message_id is an optional RunEventStore event identifier —
|
||||
# allows feedback to target a specific message or the entire run
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
@@ -33,19 +33,19 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
@@ -61,14 +61,14 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@@ -78,12 +78,12 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -94,12 +94,12 @@ class FeedbackRepository:
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -109,19 +109,97 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.rating = rating
|
||||
row.comment = comment
|
||||
row.created_at = datetime.now(UTC)
|
||||
else:
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
user_id=resolved_user_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def delete_by_run(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_by_thread_grouped(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict[str, dict]:
|
||||
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||
"""Aggregate feedback stats for a run using database-side counting."""
|
||||
stmt = select(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
script_location = %(here)s
|
||||
# Default URL for offline mode / autogenerate.
|
||||
# Runtime uses engine from DeerFlow config.
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -19,7 +19,7 @@ class RunEventRow(Base):
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
|
||||
@@ -16,7 +16,7 @@ class RunRow(Base):
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,13 +1,38 @@
|
||||
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.store.base import BaseStore
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
__all__ = [
|
||||
"MemoryThreadMetaStore",
|
||||
"ThreadMetaRepository",
|
||||
"ThreadMetaRow",
|
||||
"ThreadMetaStore",
|
||||
"make_thread_store",
|
||||
]
|
||||
|
||||
|
||||
def make_thread_store(
|
||||
session_factory: async_sessionmaker[AsyncSession] | None,
|
||||
store: BaseStore | None = None,
|
||||
) -> ThreadMetaStore:
|
||||
"""Create the appropriate ThreadMetaStore based on available backends.
|
||||
|
||||
Returns a SQL-backed repository when a session factory is available,
|
||||
otherwise falls back to the in-memory LangGraph Store implementation.
|
||||
"""
|
||||
if session_factory is not None:
|
||||
return ThreadMetaRepository(session_factory)
|
||||
if store is None:
|
||||
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
|
||||
return MemoryThreadMetaStore(store)
|
||||
|
||||
@@ -3,12 +3,21 @@
|
||||
Implementations:
|
||||
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
||||
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
||||
|
||||
All mutating and querying methods accept a ``user_id`` parameter with
|
||||
three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
||||
|
||||
- ``AUTO`` (default): resolve from the request-scoped contextvar.
|
||||
- Explicit ``str``: use the provided value verbatim.
|
||||
- Explicit ``None``: bypass owner filtering (migration/CLI only).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
||||
|
||||
|
||||
class ThreadMetaStore(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
@@ -17,14 +26,14 @@ class ThreadMetaStore(abc.ABC):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -35,26 +44,33 @@ class ThreadMetaStore(abc.ABC):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
"""Merge ``metadata`` into the thread's metadata field.
|
||||
|
||||
Existing keys are overwritten by the new values; keys absent from
|
||||
``metadata`` are preserved. No-op if the thread does not exist.
|
||||
``metadata`` are preserved. No-op if the thread does not exist
|
||||
or the owner check fails.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Any
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
|
||||
@@ -21,20 +22,37 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
def __init__(self, store: BaseStore) -> None:
|
||||
self._store = store
|
||||
|
||||
async def _get_owned_record(
|
||||
self,
|
||||
thread_id: str,
|
||||
user_id: str | None | _AutoSentinel,
|
||||
method_name: str,
|
||||
) -> dict | None:
|
||||
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
|
||||
resolved = resolve_user_id(user_id, method_name=method_name)
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return None
|
||||
record = dict(item.value)
|
||||
if resolved is not None and record.get("user_id") != resolved:
|
||||
return None
|
||||
return record
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||
now = time.time()
|
||||
record: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"owner_id": owner_id,
|
||||
"user_id": resolved_user_id,
|
||||
"display_name": display_name,
|
||||
"status": "idle",
|
||||
"metadata": metadata or {},
|
||||
@@ -45,9 +63,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
return record
|
||||
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
return item.value if item is not None else None
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
@@ -56,12 +73,16 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
||||
filter_dict: dict[str, Any] = {}
|
||||
if metadata:
|
||||
filter_dict.update(metadata)
|
||||
if status:
|
||||
filter_dict["status"] = status
|
||||
if resolved_user_id is not None:
|
||||
filter_dict["user_id"] = resolved_user_id
|
||||
|
||||
items = await self._store.asearch(
|
||||
THREADS_NS,
|
||||
@@ -71,37 +92,45 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
)
|
||||
return [self._item_to_dict(item) for item in items]
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return not require_existing
|
||||
record_user_id = item.value.get("user_id")
|
||||
if record_user_id is None:
|
||||
return True
|
||||
return record_user_id == user_id
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
|
||||
if record is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
record["display_name"] = display_name
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
|
||||
if record is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
record["status"] = status
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
|
||||
if record is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
merged = dict(record.get("metadata") or {})
|
||||
merged.update(metadata)
|
||||
record["metadata"] = merged
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||
if record is None:
|
||||
return
|
||||
await self._store.adelete(THREADS_NS, thread_id)
|
||||
|
||||
@staticmethod
|
||||
@@ -111,7 +140,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
return {
|
||||
"thread_id": item.key,
|
||||
"assistant_id": val.get("assistant_id"),
|
||||
"owner_id": val.get("owner_id"),
|
||||
"user_id": val.get("user_id"),
|
||||
"display_name": val.get("display_name"),
|
||||
"status": val.get("status", "idle"),
|
||||
"metadata": val.get("metadata", {}),
|
||||
|
||||
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=resolved_owner_id,
|
||||
user_id=resolved_user_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
@@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``.
|
||||
|
||||
Two modes — one row, two distinct semantics depending on what
|
||||
the caller is about to do:
|
||||
|
||||
- ``require_existing=False`` (default, permissive):
|
||||
Returns True for: row missing (untracked legacy thread),
|
||||
``row.owner_id`` is None (shared / pre-auth data),
|
||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
||||
``row.user_id`` is None (shared / pre-auth data),
|
||||
or ``row.user_id == user_id``. Use for **read-style**
|
||||
decorators where treating an untracked thread as accessible
|
||||
preserves backward-compat.
|
||||
|
||||
- ``require_existing=True`` (strict):
|
||||
Returns True **only** when the row exists AND
|
||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
||||
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||
state-update) so a thread that has *already been deleted*
|
||||
cannot be re-targeted by any caller — closing the
|
||||
@@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return not require_existing
|
||||
if row.owner_id is None:
|
||||
if row.user_id is None:
|
||||
return True
|
||||
return row.owner_id == owner_id
|
||||
return row.user_id == user_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
@@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
@@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_owner_id is None:
|
||||
if resolved_user_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.owner_id == resolved_owner_id
|
||||
return row is not None and row.user_id == resolved_user_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the owner_id check fails.
|
||||
the user_id check fails.
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
@@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -5,12 +5,18 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
||||
directly from ``deerflow.runtime``.
|
||||
"""
|
||||
|
||||
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||
from .store import get_store, make_store, reset_store, store_context
|
||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||
|
||||
__all__ = [
|
||||
# checkpointer
|
||||
"checkpointer_context",
|
||||
"get_checkpointer",
|
||||
"make_checkpointer",
|
||||
"reset_checkpointer",
|
||||
# runs
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
|
||||
+4
-4
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.agents.checkpointer.provider import (
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
+1
-1
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
@@ -83,8 +83,18 @@ class RunEventStore(abc.ABC):
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending."""
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
|
||||
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,16 +55,22 @@ class DbRunEventStore(RunEventStore):
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _owner_from_context() -> str | None:
|
||||
"""Soft read of owner_id from contextvar for write paths.
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
|
||||
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
||||
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
||||
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
||||
to a VARCHAR column ("type 'UUID' is not supported") — the
|
||||
INSERT would silently roll back and the worker would hang.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return user.id if user is not None else None
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
@@ -81,7 +87,7 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
owner_id = self._owner_from_context()
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
@@ -92,7 +98,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
owner_id=owner_id,
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -106,7 +112,7 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
owner_id = self._owner_from_context()
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
@@ -130,7 +136,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
owner_id=e.get("owner_id", owner_id),
|
||||
user_id=e.get("user_id", user_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -149,12 +155,12 @@ class DbRunEventStore(RunEventStore):
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
@@ -181,12 +187,12 @@ class DbRunEventStore(RunEventStore):
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
@@ -199,27 +205,46 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(
|
||||
RunEventRow.thread_id == thread_id,
|
||||
RunEventRow.run_id == run_id,
|
||||
RunEventRow.category == "message",
|
||||
)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
@@ -227,13 +252,13 @@ class DbRunEventStore(RunEventStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
@@ -246,13 +271,13 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
|
||||
@@ -152,9 +152,17 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
return [e for e in events if e.get("category") == "message"]
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
|
||||
@@ -97,9 +97,17 @@ class MemoryRunEventStore(RunEventStore):
|
||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||
return filtered[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
|
||||
@@ -50,6 +50,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
@@ -245,6 +246,19 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
# Tools that update graph state return a ``Command`` (e.g.
|
||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||
# extract it here rather than storing ``str(Command(...))``.
|
||||
if isinstance(output, Command):
|
||||
update = getattr(output, "update", None) or {}
|
||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||
if isinstance(inner_msgs, list):
|
||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||
if inner_tool_msg is not None:
|
||||
output = inner_tool_msg
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
@@ -381,6 +395,10 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
# Skip if a flush is already in flight — avoids concurrent writes
|
||||
# to the same SQLite file from multiple fire-and-forget tasks.
|
||||
if self._pending_flush_tasks:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -389,6 +407,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
task = loop.create_task(self._flush_async(batch))
|
||||
self._pending_flush_tasks.add(task)
|
||||
task.add_done_callback(self._on_flush_done)
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
@@ -404,8 +423,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Return failed events to buffer for retry on next flush
|
||||
self._buffer = batch + self._buffer
|
||||
|
||||
@staticmethod
|
||||
def _on_flush_done(task: asyncio.Task) -> None:
|
||||
def _on_flush_done(self, task: asyncio.Task) -> None:
|
||||
self._pending_flush_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
@@ -450,10 +469,17 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._buffer:
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
await self._store.put_batch(batch)
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
del self._buffer[: self._flush_threshold]
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
|
||||
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
||||
- MemoryRunStore: in-memory dict (development, tests)
|
||||
- Future: RunRepository backed by SQLAlchemy ORM
|
||||
|
||||
All methods accept an optional owner_id for user isolation.
|
||||
When owner_id is None, no user filtering is applied (single-user mode).
|
||||
All methods accept an optional user_id for user isolation.
|
||||
When user_id is None, no user filtering is applied (single-user mode).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
status: str = "pending",
|
||||
multitask_strategy: str = "reject",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id=None,
|
||||
user_id=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
|
||||
"run_id": run_id,
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"owner_id": owner_id,
|
||||
"user_id": user_id,
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata": metadata or {},
|
||||
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
|
||||
async def get(self, run_id):
|
||||
return self._runs.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
||||
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||
return results[:limit]
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class RunContext:
|
||||
store: Any | None = field(default=None)
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_meta_repo: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ async def run_agent(
|
||||
store = ctx.store
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_meta_repo = ctx.thread_meta_repo
|
||||
thread_store = ctx.thread_store
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
@@ -85,63 +85,7 @@ async def run_agent(
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
@@ -151,6 +95,38 @@ async def run_agent(
|
||||
)
|
||||
|
||||
try:
|
||||
# Initialize RunJournal + write human_message event.
|
||||
# These are inside the try block so any exception (e.g. a DB
|
||||
# error writing the event) flows through the except/finally
|
||||
# path that publishes an "end" event to the SSE bridge —
|
||||
# otherwise a failure here would leave the stream hanging
|
||||
# with no terminator.
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
@@ -334,12 +310,15 @@ async def run_agent(
|
||||
except Exception:
|
||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||
|
||||
# Persist token usage + convenience fields to RunStore
|
||||
completion = journal.get_completion_data()
|
||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||
try:
|
||||
# Persist token usage + convenience fields to RunStore
|
||||
completion = journal.get_completion_data()
|
||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
|
||||
|
||||
# Sync title from checkpoint to threads_meta.display_name
|
||||
if checkpointer is not None:
|
||||
if checkpointer is not None and thread_store is not None:
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||
@@ -347,16 +326,17 @@ async def run_agent(
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
title = ckpt.get("channel_values", {}).get("title")
|
||||
if title:
|
||||
await thread_meta_repo.update_display_name(thread_id, title)
|
||||
await thread_store.update_display_name(thread_id, title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||
|
||||
# Update threads_meta status based on run outcome
|
||||
try:
|
||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||
await thread_meta_repo.update_status(thread_id, final_status)
|
||||
except Exception:
|
||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||
if thread_store is not None:
|
||||
try:
|
||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||
await thread_store.update_status(thread_id, final_status)
|
||||
except Exception:
|
||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
@@ -91,7 +91,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
||||
configured checkpointer.
|
||||
|
||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Async stream bridge factory.
|
||||
|
||||
Provides an **async context manager** aligned with
|
||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Request-scoped user context for owner-based authorization.
|
||||
"""Request-scoped user context for user-based authorization.
|
||||
|
||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||
auth middleware sets after a successful authentication. Repository
|
||||
methods read the contextvar via a sentinel default parameter, letting
|
||||
routers stay free of ``owner_id`` boilerplate.
|
||||
routers stay free of ``user_id`` boilerplate.
|
||||
|
||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||
Three-state semantics for the repository ``user_id`` parameter (the
|
||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||
|
||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||
@@ -91,16 +91,35 @@ def require_current_user() -> CurrentUser:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based owner_id resolution
|
||||
# Effective user_id helpers (filesystem isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_USER_ID: Final[str] = "default"
|
||||
|
||||
|
||||
def get_effective_user_id() -> str:
|
||||
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
|
||||
|
||||
Unlike :func:`require_current_user` this never raises — it is designed
|
||||
for filesystem-path resolution where a valid user bucket is always needed.
|
||||
"""
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
return DEFAULT_USER_ID
|
||||
return str(user.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based user_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||
# Repository methods accept a ``user_id`` keyword-only argument that
|
||||
# defaults to ``AUTO``. The three possible values drive distinct
|
||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||
# behaviours; see the docstring on :func:`resolve_user_id`.
|
||||
|
||||
|
||||
class _AutoSentinel:
|
||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
||||
|
||||
_instance: _AutoSentinel | None = None
|
||||
|
||||
@@ -116,12 +135,12 @@ class _AutoSentinel:
|
||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||
|
||||
|
||||
def resolve_owner_id(
|
||||
def resolve_user_id(
|
||||
value: str | None | _AutoSentinel,
|
||||
*,
|
||||
method_name: str = "repository method",
|
||||
) -> str | None:
|
||||
"""Resolve the owner_id parameter passed to a repository method.
|
||||
"""Resolve the user_id parameter passed to a repository method.
|
||||
|
||||
Three-state semantics:
|
||||
|
||||
@@ -131,16 +150,16 @@ def resolve_owner_id(
|
||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||
contextvar value. Useful for tests and admin-override flows.
|
||||
- Explicit ``None``: no filter — the repository should skip the
|
||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||
user_id WHERE clause entirely. Reserved for migration scripts
|
||||
and CLI tools that intentionally bypass isolation.
|
||||
"""
|
||||
if isinstance(value, _AutoSentinel):
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
||||
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||
# ``UUID`` for the API surface, but the persistence layer
|
||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||
# not supported"). Honour the documented return type here
|
||||
# rather than ripple a type change through every caller.
|
||||
|
||||
@@ -200,8 +200,9 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None:
|
||||
if thread_id is not None:
|
||||
try:
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
host_path = get_paths().acp_workspace_dir(thread_id)
|
||||
host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||
if host_path.exists():
|
||||
return str(host_path)
|
||||
except Exception:
|
||||
|
||||
@@ -33,11 +33,12 @@ def _get_work_dir(thread_id: str | None) -> str:
|
||||
An absolute physical filesystem path to use as the working directory.
|
||||
"""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
paths = get_paths()
|
||||
if thread_id:
|
||||
try:
|
||||
work_dir = paths.acp_workspace_dir(thread_id)
|
||||
work_dir = paths.acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||
except ValueError:
|
||||
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
||||
work_dir = paths.base_dir / "acp-workspace"
|
||||
|
||||
@@ -8,6 +8,7 @@ from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
@@ -47,7 +48,7 @@ def _normalize_presented_filepath(
|
||||
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
||||
|
||||
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
|
||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
|
||||
else:
|
||||
actual_path = Path(filepath).expanduser().resolve()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
|
||||
class PathTraversalError(ValueError):
|
||||
@@ -33,7 +34,7 @@ def validate_thread_id(thread_id: str) -> None:
|
||||
def get_uploads_dir(thread_id: str) -> Path:
|
||||
"""Return the uploads directory path for a thread (no side effects)."""
|
||||
validate_thread_id(thread_id)
|
||||
return get_paths().sandbox_uploads_dir(thread_id)
|
||||
return get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
|
||||
|
||||
def ensure_uploads_dir(thread_id: str) -> Path:
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
||||
|
||||
The script is idempotent — re-running it after a successful migration is a no-op.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_thread_dirs(
|
||||
paths: Paths,
|
||||
thread_owner_map: dict[str, str],
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Move legacy thread directories into per-user layout.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
thread_owner_map: Mapping of thread_id -> user_id from threads_meta table.
|
||||
dry_run: If True, only log what would happen.
|
||||
|
||||
Returns:
|
||||
List of migration report entries.
|
||||
"""
|
||||
report: list[dict] = []
|
||||
legacy_threads = paths.base_dir / "threads"
|
||||
if not legacy_threads.exists():
|
||||
logger.info("No legacy threads directory found — nothing to migrate.")
|
||||
return report
|
||||
|
||||
for thread_dir in sorted(legacy_threads.iterdir()):
|
||||
if not thread_dir.is_dir():
|
||||
continue
|
||||
thread_id = thread_dir.name
|
||||
user_id = thread_owner_map.get(thread_id, "default")
|
||||
dest = paths.base_dir / "users" / user_id / "threads" / thread_id
|
||||
|
||||
entry = {"thread_id": thread_id, "user_id": user_id, "action": ""}
|
||||
|
||||
if dest.exists():
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / thread_id
|
||||
entry["action"] = f"conflict -> {conflicts_dir}"
|
||||
if not dry_run:
|
||||
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(thread_dir), str(conflicts_dir))
|
||||
logger.warning("Conflict for thread %s: moved to %s", thread_id, conflicts_dir)
|
||||
else:
|
||||
entry["action"] = f"moved -> {dest}"
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(thread_dir), str(dest))
|
||||
logger.info("Migrated thread %s -> user %s", thread_id, user_id)
|
||||
|
||||
report.append(entry)
|
||||
|
||||
# Clean up empty legacy threads dir
|
||||
if not dry_run and legacy_threads.exists() and not any(legacy_threads.iterdir()):
|
||||
legacy_threads.rmdir()
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def migrate_memory(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Move legacy global memory.json into per-user layout.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
user_id: Target user to receive the legacy memory.
|
||||
dry_run: If True, only log.
|
||||
"""
|
||||
legacy_mem = paths.base_dir / "memory.json"
|
||||
if not legacy_mem.exists():
|
||||
logger.info("No legacy memory.json found — nothing to migrate.")
|
||||
return
|
||||
|
||||
dest = paths.user_memory_file(user_id)
|
||||
if dest.exists():
|
||||
legacy_backup = paths.base_dir / "memory.legacy.json"
|
||||
logger.warning("Destination %s exists; renaming legacy to %s", dest, legacy_backup)
|
||||
if not dry_run:
|
||||
legacy_mem.rename(legacy_backup)
|
||||
return
|
||||
|
||||
logger.info("Migrating memory.json -> %s", dest)
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(legacy_mem), str(dest))
|
||||
|
||||
|
||||
def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
||||
"""Query threads_meta table for thread_id -> user_id mapping.
|
||||
|
||||
Uses raw sqlite3 to avoid async dependencies.
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
db_path = paths.base_dir / "deer-flow.db"
|
||||
if not db_path.exists():
|
||||
logger.info("No database found at %s — using empty owner map.", db_path)
|
||||
return {}
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
cursor = conn.execute("SELECT thread_id, user_id FROM threads_meta WHERE user_id IS NOT NULL")
|
||||
return {row[0]: row[1] for row in cursor.fetchall()}
|
||||
except sqlite3.OperationalError as e:
|
||||
logger.warning("Failed to query threads_meta: %s", e)
|
||||
return {}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
paths = get_paths()
|
||||
logger.info("Base directory: %s", paths.base_dir)
|
||||
logger.info("Dry run: %s", args.dry_run)
|
||||
|
||||
owner_map = _build_owner_map_from_db(paths)
|
||||
logger.info("Found %d thread ownership records in DB", len(owner_map))
|
||||
|
||||
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
|
||||
|
||||
if report:
|
||||
logger.info("Migration report:")
|
||||
for entry in report:
|
||||
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No threads to migrate.")
|
||||
|
||||
unowned = [e for e in report if e["user_id"] == "default"]
|
||||
if unowned:
|
||||
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
|
||||
for e in unowned:
|
||||
logger.warning(" %s", e["thread_id"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,16 +3,16 @@
|
||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||
decorators that read ``request.state.auth`` and call
|
||||
``thread_meta_repo.check_access``. Router-level unit tests construct
|
||||
``thread_store.check_access``. Router-level unit tests construct
|
||||
**bare** FastAPI apps that include only one router — they have neither
|
||||
the auth middleware nor a real thread_meta_repo, so the decorators raise
|
||||
the auth middleware nor a real thread_store, so the decorators raise
|
||||
401 (TestClient path) or ValueError (direct-call path).
|
||||
|
||||
This module provides two surfaces:
|
||||
|
||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||
request, plus a permissive ``thread_meta_repo`` mock on
|
||||
request, plus a permissive ``thread_store`` mock on
|
||||
``app.state``. Use from TestClient-based router tests.
|
||||
|
||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||
@@ -86,20 +86,20 @@ def make_authed_test_app(
|
||||
user_factory: Callable[[], User] | None = None,
|
||||
owner_check_passes: bool = True,
|
||||
) -> FastAPI:
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo.
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_store.
|
||||
|
||||
Args:
|
||||
user_factory: Override the default test user. Must return a fully
|
||||
populated :class:`User`. Useful for cross-user isolation tests
|
||||
that need a stable id across requests.
|
||||
owner_check_passes: When True (default), ``thread_meta_repo.check_access``
|
||||
owner_check_passes: When True (default), ``thread_store.check_access``
|
||||
returns True for every call so ``@require_permission(owner_check=True)``
|
||||
never blocks the route under test. Pass False to verify that
|
||||
permission failures surface correctly.
|
||||
|
||||
Returns:
|
||||
A ``FastAPI`` app with the stub middleware installed and
|
||||
``app.state.thread_meta_repo`` set to a permissive mock. The
|
||||
``app.state.thread_store`` set to a permissive mock. The
|
||||
caller is still responsible for ``app.include_router(...)``.
|
||||
"""
|
||||
factory = user_factory or _make_stub_user
|
||||
@@ -108,7 +108,7 @@ def make_authed_test_app(
|
||||
|
||||
repo = MagicMock()
|
||||
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
||||
app.state.thread_meta_repo = repo
|
||||
app.state.thread_store = repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def provisioner_module():
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``owner_id`` from a contextvar by default
|
||||
# Repository methods read ``user_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
|
||||
@@ -57,6 +57,7 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch):
|
||||
"""_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
|
||||
|
||||
@@ -95,6 +96,7 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ class TestResolveAttachments:
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.sandbox_outputs_dir.return_value = outputs_dir
|
||||
|
||||
def resolve_side_effect(tid, vpath):
|
||||
def resolve_side_effect(tid, vpath, *, user_id=None):
|
||||
if "data.csv" in vpath:
|
||||
return good_file
|
||||
return tmp_path / "missing.txt"
|
||||
|
||||
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -78,7 +78,7 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
@@ -178,7 +178,7 @@ class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
"""Async SQLite setup should move mkdir off the event loop."""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||
@@ -195,11 +195,11 @@ class TestAsyncCheckpointer:
|
||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch(
|
||||
"deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
return_value="/tmp/resolved/test.db",
|
||||
),
|
||||
):
|
||||
|
||||
@@ -12,14 +12,14 @@ class TestCheckpointerNoneFix:
|
||||
@pytest.mark.anyio
|
||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
mock_config.database = None
|
||||
|
||||
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
async with make_checkpointer() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
@@ -36,13 +36,13 @@ class TestCheckpointerNoneFix:
|
||||
|
||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.agents.checkpointer.provider import checkpointer_context
|
||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with checkpointer_context() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
|
||||
@@ -817,7 +817,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._available_skills = {"test_skill"}
|
||||
@@ -842,7 +842,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -867,7 +867,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -886,7 +886,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -1015,7 +1015,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
# No internal checkpointer, should fetch from provider
|
||||
result = client.list_threads()
|
||||
|
||||
@@ -1069,7 +1069,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
result = client.get_thread("t99")
|
||||
|
||||
assert result["thread_id"] == "t99"
|
||||
@@ -1241,7 +1241,10 @@ class TestMemoryManagement:
|
||||
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
||||
result = client.import_memory(imported)
|
||||
|
||||
mock_import.assert_called_once_with(imported)
|
||||
assert mock_import.call_count == 1
|
||||
call_args = mock_import.call_args
|
||||
assert call_args.args == (imported,)
|
||||
assert "user_id" in call_args.kwargs
|
||||
assert result == imported
|
||||
|
||||
def test_reload_memory(self, client):
|
||||
@@ -1487,9 +1490,12 @@ class TestUploads:
|
||||
|
||||
class TestArtifacts:
|
||||
def test_get_artifact(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "result.txt").write_text("artifact content")
|
||||
|
||||
@@ -1500,9 +1506,12 @@ class TestArtifacts:
|
||||
assert "text" in mime
|
||||
|
||||
def test_get_artifact_not_found(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
@@ -1513,9 +1522,12 @@ class TestArtifacts:
|
||||
client.get_artifact("t1", "bad/path/file.txt")
|
||||
|
||||
def test_get_artifact_path_traversal(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(PathTraversalError):
|
||||
@@ -1699,13 +1711,16 @@ class TestScenarioFileLifecycle:
|
||||
|
||||
def test_upload_then_read_artifact(self, client):
|
||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
|
||||
paths = Paths(base_dir=tmp_path)
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact")
|
||||
user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id)
|
||||
outputs_dir.mkdir(parents=True)
|
||||
|
||||
# Upload phase
|
||||
@@ -1844,7 +1859,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config_a)
|
||||
first_agent = client._agent
|
||||
@@ -1872,7 +1887,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client._ensure_agent(config)
|
||||
@@ -1897,7 +1912,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client.reset_agent()
|
||||
@@ -1955,11 +1970,14 @@ class TestScenarioThreadIsolation:
|
||||
|
||||
def test_artifacts_isolated_per_thread(self, client):
|
||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a")
|
||||
user_id = get_effective_user_id()
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id)
|
||||
outputs_a.mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("thread-b").mkdir(parents=True)
|
||||
paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True)
|
||||
(outputs_a / "result.txt").write_text("thread-a artifact")
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2864,9 +2882,12 @@ class TestUploadDeleteSymlink:
|
||||
class TestArtifactHardening:
|
||||
def test_artifact_directory_rejected(self, client):
|
||||
"""get_artifact rejects paths that resolve to a directory."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
subdir = paths.sandbox_outputs_dir("t1") / "subdir"
|
||||
user_id = get_effective_user_id()
|
||||
subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir"
|
||||
subdir.mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2875,9 +2896,12 @@ class TestArtifactHardening:
|
||||
|
||||
def test_artifact_leading_slash_stripped(self, client):
|
||||
"""Paths with leading slash are handled correctly."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "file.txt").write_text("content")
|
||||
|
||||
@@ -2991,9 +3015,12 @@ class TestBugArtifactPrefixMatchTooLoose:
|
||||
|
||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
# Accepted at prefix check, but fails because it's a directory.
|
||||
|
||||
@@ -262,8 +262,9 @@ class TestFileUploadIntegration:
|
||||
|
||||
# Physically exists
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists()
|
||||
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
||||
|
||||
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
||||
"""Uploading two files with the same name auto-renames the second."""
|
||||
@@ -472,12 +473,13 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_happy_path(self, e2e_env):
|
||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
# Create an output file in the thread's outputs directory
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||
(outputs_dir / "result.txt").write_text("hello artifact")
|
||||
|
||||
@@ -488,11 +490,12 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_nested_path(self, e2e_env):
|
||||
"""Artifacts in subdirectories are accessible."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
sub = outputs_dir / "charts"
|
||||
sub.mkdir(parents=True, exist_ok=True)
|
||||
(sub / "data.json").write_text('{"x": 1}')
|
||||
|
||||
@@ -199,12 +199,12 @@ def test_migration_failure_is_non_fatal():
|
||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||
|
||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||
(no auth) accumulates threads in the LangGraph Store namespace
|
||||
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
||||
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||
rewrite each unowned item with the freshly created admin's id.
|
||||
"""
|
||||
@@ -215,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
||||
]
|
||||
store = AsyncMock()
|
||||
# asearch returns the entire batch on first call, then an empty page
|
||||
@@ -235,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
assert len(aput_calls) == 3
|
||||
rewritten_keys = {call[1] for call in aput_calls}
|
||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||
# Each rewrite carries the new owner_id; titles preserved where present.
|
||||
# Each rewrite carries the new user_id; titles preserved where present.
|
||||
by_key = {call[1]: call[2] for call in aput_calls}
|
||||
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
||||
# The pre-owned item must NOT have been rewritten.
|
||||
assert "t4" not in rewritten_keys
|
||||
|
||||
|
||||
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
||||
assert record["owner_id"] == "user-1"
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
assert record["user_id"] == "user-1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -97,10 +97,10 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1)
|
||||
results = await repo.list_by_run("t1", "r1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||
assert len(results) == 2
|
||||
assert all(r["run_id"] == "r1" for r in results)
|
||||
await _cleanup()
|
||||
@@ -135,9 +135,9 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
||||
stats = await repo.aggregate_by_run("t1", "r1")
|
||||
assert stats["total"] == 3
|
||||
assert stats["positive"] == 2
|
||||
@@ -154,6 +154,80 @@ class TestFeedbackRepository:
|
||||
assert stats["negative"] == 0
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_creates_new(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
assert record["rating"] == 1
|
||||
assert record["feedback_id"]
|
||||
assert record["user_id"] == "u1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_updates_existing(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
|
||||
assert second["feedback_id"] == first["feedback_id"]
|
||||
assert second["rating"] == -1
|
||||
assert second["comment"] == "changed my mind"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_different_users_separate(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
||||
assert r1["feedback_id"] != r2["feedback_id"]
|
||||
assert r1["rating"] == 1
|
||||
assert r2["rating"] == -1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_invalid_rating(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
with pytest.raises(ValueError):
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is True
|
||||
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
||||
assert len(results) == 0
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run_nonexistent(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
||||
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert "r1" in grouped
|
||||
assert "r2" in grouped
|
||||
assert "r3" not in grouped
|
||||
assert grouped["r1"]["rating"] == 1
|
||||
assert grouped["r2"]["rating"] == -1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert grouped == {}
|
||||
await _cleanup()
|
||||
|
||||
|
||||
# -- Follow-up association --
|
||||
|
||||
|
||||
@@ -152,8 +152,10 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
result = _get_work_dir("thread-abc-123")
|
||||
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
||||
assert result == str(expected)
|
||||
@@ -310,8 +312,10 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
|
||||
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
||||
def test_filter_injects_user_id():
|
||||
value = {}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
|
||||
value = {"metadata": {"title": "hello"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
assert value["metadata"]["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||
assert result == {"owner_id": "user-x"}
|
||||
assert result == {"user_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
|
||||
value = {}
|
||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
|
||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||
assert f_a["owner_id"] != f_b["owner_id"]
|
||||
assert f_a["user_id"] != f_b["user_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
|
||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||
value = {"metadata": {"owner_id": "attacker"}}
|
||||
value = {"metadata": {"user_id": "attacker"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||
assert value["metadata"]["owner_id"] == "real-owner"
|
||||
assert value["metadata"]["user_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
|
||||
"""Explicit empty metadata dict is fine."""
|
||||
value = {"metadata": {}}
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||
assert value["metadata"]["owner_id"] == "user-z"
|
||||
assert result == {"owner_id": "user-z"}
|
||||
assert value["metadata"]["user_id"] == "user-z"
|
||||
assert result == {"user_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
@@ -48,6 +48,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,4 +89,5 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Tests for user_id propagation through memory queue."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
|
||||
ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
|
||||
assert ctx.user_id == "alice"
|
||||
|
||||
|
||||
def test_conversation_context_user_id_default_none():
|
||||
ctx = ConversationContext(thread_id="t1", messages=[])
|
||||
assert ctx.user_id is None
|
||||
|
||||
|
||||
def test_queue_add_stores_user_id():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
assert len(q._queue) == 1
|
||||
assert q._queue[0].user_id == "alice"
|
||||
q.clear()
|
||||
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
q._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once()
|
||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "alice"
|
||||
@@ -258,12 +258,13 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
update_fact.assert_called_once_with(
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
category=None,
|
||||
confidence=None,
|
||||
)
|
||||
assert update_fact.call_count == 1
|
||||
call_kwargs = update_fact.call_args.kwargs
|
||||
assert call_kwargs.get("fact_id") == "fact_edit"
|
||||
assert call_kwargs.get("content") == "User prefers spaces"
|
||||
assert call_kwargs.get("category") is None
|
||||
assert call_kwargs.get("confidence") is None
|
||||
assert "user_id" in call_kwargs
|
||||
assert response.json()["facts"] == updated_memory["facts"]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Tests for per-user memory storage isolation."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_dir(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage() -> FileMemoryStorage:
|
||||
return FileMemoryStorage()
|
||||
|
||||
|
||||
class TestUserIsolatedStorage:
|
||||
def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
memory_a = create_empty_memory()
|
||||
memory_a["user"]["workContext"]["summary"] = "User A context"
|
||||
storage.save(memory_a, user_id="alice")
|
||||
|
||||
memory_b = create_empty_memory()
|
||||
memory_b["user"]["workContext"]["summary"] = "User B context"
|
||||
storage.save(memory_b, user_id="bob")
|
||||
|
||||
loaded_a = storage.load(user_id="alice")
|
||||
loaded_b = storage.load(user_id="bob")
|
||||
|
||||
assert loaded_a["user"]["workContext"]["summary"] == "User A context"
|
||||
assert loaded_b["user"]["workContext"]["summary"] == "User B context"
|
||||
|
||||
def test_user_memory_file_location(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
expected_path = base_dir / "users" / "alice" / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_cache_isolated_per_user(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory_a = create_empty_memory()
|
||||
memory_a["user"]["workContext"]["summary"] = "A"
|
||||
s.save(memory_a, user_id="alice")
|
||||
|
||||
memory_b = create_empty_memory()
|
||||
memory_b["user"]["workContext"]["summary"] = "B"
|
||||
s.save(memory_b, user_id="bob")
|
||||
|
||||
loaded_a = s.load(user_id="alice")
|
||||
assert loaded_a["user"]["workContext"]["summary"] == "A"
|
||||
|
||||
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id=None)
|
||||
expected_path = base_dir / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
|
||||
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
|
||||
legacy_mem = create_empty_memory()
|
||||
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
||||
s.save(legacy_mem, user_id=None)
|
||||
|
||||
user_mem = create_empty_memory()
|
||||
user_mem["user"]["workContext"]["summary"] = "alice"
|
||||
s.save(user_mem, user_id="alice")
|
||||
|
||||
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
||||
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
||||
|
||||
def test_user_agent_memory_file_location(self, base_dir: Path):
|
||||
"""Per-user per-agent memory uses the user_agent_memory_file path."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "agent scoped"
|
||||
s.save(memory, "test-agent", user_id="alice")
|
||||
expected_path = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
|
||||
"""Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
# After save, cache should have tuple key
|
||||
assert ("alice", None) in s._memory_cache
|
||||
|
||||
def test_reload_with_user_id(self, base_dir: Path):
|
||||
"""reload() with user_id should force re-read from the user-scoped file."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "initial"
|
||||
s.save(memory, user_id="alice")
|
||||
|
||||
# Load once to prime cache
|
||||
s.load(user_id="alice")
|
||||
|
||||
# Write updated content directly to file
|
||||
user_file = base_dir / "users" / "alice" / "memory.json"
|
||||
import json
|
||||
|
||||
updated = create_empty_memory()
|
||||
updated["user"]["workContext"]["summary"] = "updated"
|
||||
user_file.write_text(json.dumps(updated))
|
||||
|
||||
# reload should pick up the new content
|
||||
reloaded = s.reload(user_id="alice")
|
||||
assert reloaded["user"]["workContext"]["summary"] == "updated"
|
||||
@@ -0,0 +1,156 @@
|
||||
"""Owner isolation tests for MemoryThreadMetaStore.
|
||||
|
||||
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
||||
the in-memory LangGraph Store backend used when database.backend=memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
def _as_user(user):
|
||||
class _Ctx:
|
||||
def __enter__(self):
|
||||
self._token = set_current_user(user)
|
||||
return user
|
||||
|
||||
def __exit__(self, *exc):
|
||||
reset_current_user(self._token)
|
||||
|
||||
return _Ctx()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
return MemoryThreadMetaStore(InMemoryStore())
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_search_isolation(store):
|
||||
"""search() returns only threads owned by the current user."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="A's thread")
|
||||
with _as_user(USER_B):
|
||||
await store.create("t-beta", display_name="B's thread")
|
||||
|
||||
with _as_user(USER_A):
|
||||
results = await store.search()
|
||||
assert [r["thread_id"] for r in results] == ["t-alpha"]
|
||||
|
||||
with _as_user(USER_B):
|
||||
results = await store.search()
|
||||
assert [r["thread_id"] for r in results] == ["t-beta"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_get_isolation(store):
|
||||
"""get() returns None for threads owned by another user."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="A's thread")
|
||||
|
||||
with _as_user(USER_B):
|
||||
assert await store.get("t-alpha") is None
|
||||
|
||||
with _as_user(USER_A):
|
||||
result = await store.get("t-alpha")
|
||||
assert result is not None
|
||||
assert result["display_name"] == "A's thread"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_display_name_denied(store):
|
||||
"""User B cannot rename User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="original")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_display_name("t-alpha", "hacked")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["display_name"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_status_denied(store):
|
||||
"""User B cannot change status of User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_status("t-alpha", "error")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["status"] == "idle"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_metadata_denied(store):
|
||||
"""User B cannot modify metadata of User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", metadata={"key": "original"})
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_metadata("t-alpha", {"key": "hacked"})
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["metadata"]["key"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_delete_denied(store):
|
||||
"""User B cannot delete User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.delete("t-alpha")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_no_context_raises(store):
|
||||
"""Calling methods without user context raises RuntimeError."""
|
||||
with pytest.raises(RuntimeError, match="no user context is set"):
|
||||
await store.search()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(store):
|
||||
"""user_id=None bypasses isolation (migration/CLI escape hatch)."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
with _as_user(USER_B):
|
||||
await store.create("t-beta")
|
||||
|
||||
all_rows = await store.search(user_id=None)
|
||||
assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"}
|
||||
|
||||
row = await store.get("t-alpha", user_id=None)
|
||||
assert row is not None
|
||||
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
result = import_memory_data(imported_memory)
|
||||
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
||||
mock_storage.load.assert_called_once_with(None)
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||
assert result == imported_memory
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Tests for user_id propagation in memory updater."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
||||
|
||||
|
||||
def test_get_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.load.return_value = {"version": "1.0"}
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
get_memory_data(user_id="alice")
|
||||
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
||||
|
||||
|
||||
def test_save_memory_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
||||
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||
|
||||
|
||||
def test_clear_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
clear_memory_data(user_id="charlie")
|
||||
# Verify save was called with user_id
|
||||
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Tests for per-user data migration."""
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_dir(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paths(base_dir: Path) -> Paths:
|
||||
return Paths(base_dir)
|
||||
|
||||
|
||||
class TestMigrateThreadDirs:
|
||||
def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
(legacy / "file.txt").write_text("hello")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
|
||||
assert expected.exists()
|
||||
assert expected.read_text() == "hello"
|
||||
assert not (base_dir / "threads" / "t1").exists()
|
||||
|
||||
def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t2" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
expected = base_dir / "users" / "default" / "threads" / "t2"
|
||||
assert expected.exists()
|
||||
|
||||
def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
|
||||
new_dir = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||
new_dir.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
assert new_dir.exists()
|
||||
|
||||
def test_conflict_preserved(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
(legacy / "old.txt").write_text("old")
|
||||
|
||||
dest = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||
dest.mkdir(parents=True)
|
||||
(dest / "new.txt").write_text("new")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
assert (dest / "new.txt").read_text() == "new"
|
||||
conflicts = base_dir / "migration-conflicts" / "t1"
|
||||
assert conflicts.exists()
|
||||
|
||||
def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
assert not (base_dir / "threads").exists()
|
||||
|
||||
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
|
||||
|
||||
assert len(report) == 1
|
||||
assert (base_dir / "threads" / "t1").exists() # not moved
|
||||
assert not (base_dir / "users" / "alice" / "threads" / "t1").exists()
|
||||
|
||||
|
||||
class TestMigrateMemory:
|
||||
def test_moves_global_memory(self, base_dir: Path, paths: Paths):
|
||||
legacy_mem = base_dir / "memory.json"
|
||||
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
expected = base_dir / "users" / "default" / "memory.json"
|
||||
assert expected.exists()
|
||||
assert not legacy_mem.exists()
|
||||
|
||||
def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
|
||||
legacy_mem = base_dir / "memory.json"
|
||||
legacy_mem.write_text(json.dumps({"version": "old"}))
|
||||
|
||||
dest = base_dir / "users" / "default" / "memory.json"
|
||||
dest.parent.mkdir(parents=True)
|
||||
dest.write_text(json.dumps({"version": "new"}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
assert json.loads(dest.read_text())["version"] == "new"
|
||||
assert (base_dir / "memory.legacy.json").exists()
|
||||
|
||||
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default") # should not raise
|
||||
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
||||
owner filter directly by switching the ``user_context`` contextvar
|
||||
between two users. The safety property under test is:
|
||||
|
||||
After a repository write with owner_id=A, a subsequent read with
|
||||
owner_id=B must not return the row, and vice versa.
|
||||
After a repository write with user_id=A, a subsequent read with
|
||||
user_id=B must not return the row, and vice versa.
|
||||
|
||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||
that a request cookie reaches the ``set_current_user`` call. Together
|
||||
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
||||
await cleanup()
|
||||
|
||||
|
||||
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
||||
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
await repo.create("t-beta")
|
||||
|
||||
# Migration-style read: no contextvar, explicit None bypass.
|
||||
all_rows = await repo.search(owner_id=None)
|
||||
all_rows = await repo.search(user_id=None)
|
||||
thread_ids = {r["thread_id"] for r in all_rows}
|
||||
assert thread_ids == {"t-alpha", "t-beta"}
|
||||
|
||||
# Explicit get with None does not apply the filter either.
|
||||
row_a = await repo.get("t-alpha", owner_id=None)
|
||||
row_a = await repo.get("t-alpha", user_id=None)
|
||||
assert row_a is not None
|
||||
row_b = await repo.get("t-beta", owner_id=None)
|
||||
row_b = await repo.get("t-beta", user_id=None)
|
||||
assert row_b is not None
|
||||
finally:
|
||||
await cleanup()
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Tests for user-scoped path resolution in Paths."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paths(tmp_path: Path) -> Paths:
|
||||
return Paths(tmp_path)
|
||||
|
||||
|
||||
class TestValidateUserId:
|
||||
def test_valid_user_id(self, paths: Paths):
|
||||
d = paths.user_dir("u-abc-123")
|
||||
assert d == paths.base_dir / "users" / "u-abc-123"
|
||||
|
||||
def test_rejects_path_traversal(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("../escape")
|
||||
|
||||
def test_rejects_slash(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("foo/bar")
|
||||
|
||||
def test_rejects_empty(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("")
|
||||
|
||||
|
||||
class TestUserDir:
|
||||
def test_user_dir(self, paths: Paths):
|
||||
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
||||
|
||||
|
||||
class TestUserMemoryFile:
|
||||
def test_user_memory_file(self, paths: Paths):
|
||||
assert paths.user_memory_file("bob") == paths.base_dir / "users" / "bob" / "memory.json"
|
||||
|
||||
|
||||
class TestUserAgentMemoryFile:
|
||||
def test_user_agent_memory_file(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||
assert paths.user_agent_memory_file("bob", "myagent") == expected
|
||||
|
||||
def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
||||
|
||||
|
||||
class TestUserThreadDir:
|
||||
def test_user_thread_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
||||
assert paths.thread_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
|
||||
expected = paths.base_dir / "threads" / "t1"
|
||||
assert paths.thread_dir("t1") == expected
|
||||
|
||||
|
||||
class TestUserSandboxDirs:
|
||||
def test_sandbox_work_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "workspace"
|
||||
assert paths.sandbox_work_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_uploads_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "uploads"
|
||||
assert paths.sandbox_uploads_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_outputs_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "outputs"
|
||||
assert paths.sandbox_outputs_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_user_data_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data"
|
||||
assert paths.sandbox_user_data_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_acp_workspace_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "acp-workspace"
|
||||
assert paths.acp_workspace_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_legacy_sandbox_work_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
assert paths.sandbox_work_dir("t1") == expected
|
||||
|
||||
|
||||
class TestHostPathsWithUserId:
|
||||
def test_host_thread_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_thread_dir("t1", user_id="u1")
|
||||
assert "users" in result
|
||||
assert "u1" in result
|
||||
assert "threads" in result
|
||||
assert "t1" in result
|
||||
|
||||
def test_host_thread_dir_legacy(self, paths: Paths):
|
||||
result = paths.host_thread_dir("t1")
|
||||
assert "threads" in result
|
||||
assert "t1" in result
|
||||
assert "users" not in result
|
||||
|
||||
def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_user_data_dir("t1", user_id="u1")
|
||||
assert "users" in result
|
||||
assert "user-data" in result
|
||||
|
||||
def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_work_dir("t1", user_id="u1")
|
||||
assert "workspace" in result
|
||||
|
||||
def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_uploads_dir("t1", user_id="u1")
|
||||
assert "uploads" in result
|
||||
|
||||
def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_outputs_dir("t1", user_id="u1")
|
||||
assert "outputs" in result
|
||||
|
||||
def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_acp_workspace_dir("t1", user_id="u1")
|
||||
assert "acp-workspace" in result
|
||||
|
||||
|
||||
class TestEnsureAndDeleteWithUserId:
|
||||
def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
assert paths.sandbox_work_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.sandbox_uploads_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.sandbox_outputs_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.acp_workspace_dir("t1", user_id="u1").is_dir()
|
||||
|
||||
def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||
paths.delete_thread_dir("t1", user_id="u1")
|
||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||
|
||||
def test_delete_thread_dir_idempotent(self, paths: Paths):
|
||||
paths.delete_thread_dir("nonexistent", user_id="u1") # should not raise
|
||||
|
||||
def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1")
|
||||
assert paths.sandbox_work_dir("t1").is_dir()
|
||||
|
||||
def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
paths.ensure_thread_dirs("t1")
|
||||
# Both exist independently
|
||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||
assert paths.thread_dir("t1").exists()
|
||||
# Delete one doesn't affect the other
|
||||
paths.delete_thread_dir("t1", user_id="u1")
|
||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||
assert paths.thread_dir("t1").exists()
|
||||
|
||||
|
||||
class TestResolveVirtualPathWithUserId:
|
||||
def test_resolve_virtual_path_with_user_id(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
|
||||
expected_base = paths.sandbox_user_data_dir("t1", user_id="u1").resolve()
|
||||
assert str(result).startswith(str(expected_base))
|
||||
|
||||
def test_resolve_virtual_path_legacy(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1")
|
||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
|
||||
expected_base = paths.sandbox_user_data_dir("t1").resolve()
|
||||
assert str(result).startswith(str(expected_base))
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Tests:
|
||||
1. DatabaseConfig property derivation (paths, URLs)
|
||||
2. MemoryRunStore CRUD + owner_id filtering
|
||||
2. MemoryRunStore CRUD + user_id filtering
|
||||
3. Base.to_dict() via inspect mixin
|
||||
4. Engine init/close lifecycle (memory + SQLite)
|
||||
5. Postgres missing-dep error message
|
||||
@@ -24,18 +24,19 @@ class TestDatabaseConfig:
|
||||
assert c.backend == "memory"
|
||||
assert c.pool_size == 5
|
||||
|
||||
def test_sqlite_paths_are_different(self):
|
||||
def test_sqlite_paths_unified(self):
|
||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
||||
assert c.checkpointer_sqlite_path.endswith("checkpoints.db")
|
||||
assert c.app_sqlite_path.endswith("app.db")
|
||||
assert "mydata" in c.checkpointer_sqlite_path
|
||||
assert c.checkpointer_sqlite_path != c.app_sqlite_path
|
||||
assert c.sqlite_path.endswith("deerflow.db")
|
||||
assert "mydata" in c.sqlite_path
|
||||
# Backward-compatible aliases point to the same file
|
||||
assert c.checkpointer_sqlite_path == c.sqlite_path
|
||||
assert c.app_sqlite_path == c.sqlite_path
|
||||
|
||||
def test_app_sqlalchemy_url_sqlite(self):
|
||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
||||
url = c.app_sqlalchemy_url
|
||||
assert url.startswith("sqlite+aiosqlite:///")
|
||||
assert "app.db" in url
|
||||
assert "deerflow.db" in url
|
||||
|
||||
def test_app_sqlalchemy_url_postgres(self):
|
||||
c = DatabaseConfig(
|
||||
@@ -105,17 +106,17 @@ class TestMemoryRunStore:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, store):
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id="alice")
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, store):
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id=None)
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id=None)
|
||||
assert len(rows) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
present_file_tool_module,
|
||||
"get_paths",
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
|
||||
)
|
||||
|
||||
result = present_file_tool_module.present_file_tool.func(
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_store():
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_default_returns_all(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
assert len(msgs) == 7
|
||||
assert all(m["category"] == "message" for m in msgs)
|
||||
assert all(m["run_id"] == "run-a" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_with_limit(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
||||
assert len(msgs) == 3
|
||||
seqs = [m["seq"] for m in msgs]
|
||||
assert seqs == sorted(seqs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_after_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[2]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] > cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_before_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[4]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] < cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message", category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-b")
|
||||
assert len(msgs) == 3
|
||||
assert all(m["run_id"] == "run-b" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_empty_run(base_store):
|
||||
store = base_store
|
||||
msgs = await store.list_messages_by_run("t1", "nonexistent")
|
||||
assert msgs == []
|
||||
@@ -709,6 +709,81 @@ class TestToolResultMessage:
|
||||
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
|
||||
assert tool_end["metadata"]["tool_name"] == "web_search"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_invoke_end_to_end_unwraps_command(self, journal_setup):
|
||||
"""End-to-end: invoke a real LangChain tool that returns Command(update={'messages':[ToolMessage]}).
|
||||
|
||||
This goes through the real LangChain callback path (tool.invoke -> CallbackManager
|
||||
-> on_tool_start/on_tool_end), which is what the production agent uses. Mirrors
|
||||
the ``present_files`` tool shape exactly.
|
||||
"""
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
j, store = journal_setup
|
||||
|
||||
@tool
|
||||
def fake_present_files(filepaths: list[str]) -> Command:
|
||||
"""Fake present_files that returns a Command with an inner ToolMessage."""
|
||||
return Command(
|
||||
update={
|
||||
"artifacts": filepaths,
|
||||
"messages": [ToolMessage("Successfully presented files", tool_call_id="tc_123")],
|
||||
},
|
||||
)
|
||||
|
||||
# Real LangChain callback dispatch (matches production agent path)
|
||||
cm = CallbackManager(handlers=[j])
|
||||
fake_present_files.invoke(
|
||||
{"filepaths": ["/mnt/user-data/outputs/report.md"]},
|
||||
config={"callbacks": cm, "run_id": uuid4()},
|
||||
)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1, f"expected 1 message event, got {len(messages)}: {messages}"
|
||||
content = messages[0]["content"]
|
||||
assert content["type"] == "tool"
|
||||
# CRITICAL: must be the inner ToolMessage text, not str(Command(...))
|
||||
assert content["content"] == "Successfully presented files", (
|
||||
f"Command unwrap failed; stored content = {content['content']!r}"
|
||||
)
|
||||
assert "Command(update=" not in str(content["content"])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_end_unwraps_command_with_inner_tool_message(self, journal_setup):
|
||||
"""Tools like ``present_files`` return Command(update={'messages': [ToolMessage(...)]}).
|
||||
|
||||
LangGraph unwraps the inner ToolMessage into checkpoint state, so the
|
||||
event store must do the same — otherwise it captures ``str(Command(...))``
|
||||
and the /history response diverges from the real rendered message.
|
||||
"""
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
inner = ToolMessage(
|
||||
content="Successfully presented files",
|
||||
tool_call_id="call_present",
|
||||
name="present_files",
|
||||
status="success",
|
||||
)
|
||||
cmd = Command(update={"artifacts": ["/mnt/user-data/outputs/report.md"], "messages": [inner]})
|
||||
j.on_tool_end(cmd, run_id=run_id)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
assert content["type"] == "tool"
|
||||
assert content["content"] == "Successfully presented files"
|
||||
assert content["tool_call_id"] == "call_present"
|
||||
assert content["name"] == "present_files"
|
||||
assert "Command(update=" not in str(content["content"])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
|
||||
"""ToolMessage object fields take priority over kwargs."""
|
||||
|
||||
@@ -73,11 +73,11 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -189,8 +189,8 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id=None)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id=None)
|
||||
assert len(rows) == 2
|
||||
await _cleanup()
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(run_store=None, event_store=None, feedback_repo=None):
|
||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||
app = make_authed_test_app()
|
||||
app.include_router(runs.router)
|
||||
|
||||
if run_store is not None:
|
||||
app.state.run_store = run_store
|
||||
if event_store is not None:
|
||||
app.state.run_event_store = event_store
|
||||
if feedback_repo is not None:
|
||||
app.state.feedback_repo = feedback_repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _make_run_store(run_record: dict | None):
|
||||
"""Return an AsyncMock run store whose get() returns run_record."""
|
||||
store = MagicMock()
|
||||
store.get = AsyncMock(return_value=run_record)
|
||||
return store
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_messages_returns_envelope():
|
||||
"""GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
|
||||
rows = [_make_message(i) for i in range(1, 4)]
|
||||
run_record = {"run_id": "run-1", "thread_id": "thread-1"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-1/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "data" in body
|
||||
assert "has_more" in body
|
||||
assert body["has_more"] is False
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
|
||||
def test_run_messages_404_when_run_not_found():
|
||||
"""Returns 404 when the run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/messages")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_run_messages_has_more_true_when_extra_row_returned():
|
||||
"""has_more=True when event store returns limit+1 rows."""
|
||||
# Default limit is 50; provide 51 rows
|
||||
rows = [_make_message(i) for i in range(1, 52)] # 51 rows
|
||||
run_record = {"run_id": "run-2", "thread_id": "thread-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-2/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["has_more"] is True
|
||||
assert len(body["data"]) == 50 # trimmed to limit
|
||||
|
||||
|
||||
def test_run_messages_passes_after_seq_to_event_store():
|
||||
"""after_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(10)]
|
||||
run_record = {"run_id": "run-3", "thread_id": "thread-3"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_respects_custom_limit():
|
||||
"""Custom limit is respected and capped at 200."""
|
||||
rows = [_make_message(i) for i in range(1, 6)]
|
||||
run_record = {"run_id": "run-4", "thread_id": "thread-4"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-4/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4", "run-4",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_passes_before_seq_to_event_store():
|
||||
"""before_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(3)]
|
||||
run_record = {"run_id": "run-5", "thread_id": "thread-5"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-5/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5", "run-5",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_empty_data():
|
||||
"""Returns empty data list when no messages exist."""
|
||||
run_record = {"run_id": "run-6", "thread_id": "thread-6"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-6/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["data"] == []
|
||||
assert body["has_more"] is False
|
||||
|
||||
|
||||
def _make_feedback_repo(rows: list[dict]):
|
||||
"""Return an AsyncMock feedback repo whose list_by_run() returns rows."""
|
||||
repo = MagicMock()
|
||||
repo.list_by_run = AsyncMock(return_value=rows)
|
||||
return repo
|
||||
|
||||
|
||||
def _make_feedback(run_id: str, idx: int) -> dict:
|
||||
return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestRunFeedback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFeedback:
|
||||
def test_returns_list_of_feedback_dicts(self):
|
||||
"""GET /api/runs/{run_id}/feedback returns a list of feedback dicts."""
|
||||
run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"}
|
||||
rows = [_make_feedback("run-fb-1", i) for i in range(3)]
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-1/feedback")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert isinstance(body, list)
|
||||
assert len(body) == 3
|
||||
|
||||
def test_404_when_run_not_found(self):
|
||||
"""Returns 404 when run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/feedback")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
def test_empty_list_when_no_feedback(self):
|
||||
"""Returns empty list when no feedback exists for the run."""
|
||||
run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-2/feedback")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_503_when_feedback_repo_not_configured(self):
|
||||
"""Returns 503 when feedback_repo is None (no DB configured)."""
|
||||
run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
)
|
||||
# Explicitly set feedback_repo to None to simulate missing DB
|
||||
app.state.feedback_repo = None
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-3/feedback")
|
||||
assert response.status_code == 503
|
||||
@@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
||||
@@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == []
|
||||
|
||||
@@ -43,8 +43,8 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", owner_id="user1", display_name="My Thread")
|
||||
assert record["owner_id"] == "user1"
|
||||
record = await repo.create("t1", user_id="user1", display_name="My Thread")
|
||||
assert record["user_id"] == "user1"
|
||||
assert record["display_name"] == "My Thread"
|
||||
await _cleanup()
|
||||
|
||||
@@ -61,26 +61,6 @@ class TestThreadMetaRepository:
|
||||
assert await repo.get("nonexistent") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_owner(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t2", owner_id="user1")
|
||||
await repo.create("t3", owner_id="user2")
|
||||
results = await repo.list_by_owner("user1")
|
||||
assert len(results) == 2
|
||||
assert all(r["owner_id"] == "user1" for r in results)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_owner_with_limit_and_offset(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
for i in range(5):
|
||||
await repo.create(f"t{i}", owner_id="user1")
|
||||
results = await repo.list_by_owner("user1", limit=2, offset=1)
|
||||
assert len(results) == 2
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_record_allows(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
@@ -90,23 +70,23 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_matches(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2") is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
# Explicit owner_id=None to bypass the new AUTO default that
|
||||
# Explicit user_id=None to bypass the new AUTO default that
|
||||
# would otherwise pick up the test user from the autouse fixture.
|
||||
await repo.create("t1", owner_id=None)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone") is True
|
||||
await _cleanup()
|
||||
|
||||
@@ -125,27 +105,27 @@ class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id="user1")
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||
"""Even in strict mode, a row with NULL owner_id stays shared.
|
||||
"""Even in strict mode, a row with NULL user_id stays shared.
|
||||
|
||||
The strict flag tightens the *missing row* case, not the *shared
|
||||
row* case — legacy pre-auth rows that survived a clean migration
|
||||
without an owner are still everyone's.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", owner_id=None)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import thread_runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(event_store=None):
|
||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||
app = make_authed_test_app()
|
||||
app.include_router(thread_runs.router)
|
||||
|
||||
if event_store is not None:
|
||||
app.state.run_event_store = event_store
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_returns_paginated_envelope():
|
||||
"""GET /api/threads/{tid}/runs/{rid}/messages returns {data: [...], has_more: bool}."""
|
||||
rows = [_make_message(i) for i in range(1, 4)]
|
||||
app = _make_app(event_store=_make_event_store(rows))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-1/runs/run-1/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "data" in body
|
||||
assert "has_more" in body
|
||||
assert body["has_more"] is False
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
|
||||
def test_has_more_true_when_extra_row_returned():
|
||||
"""has_more=True when event store returns limit+1 rows."""
|
||||
# Default limit is 50; provide 51 rows
|
||||
rows = [_make_message(i) for i in range(1, 52)] # 51 rows
|
||||
app = _make_app(event_store=_make_event_store(rows))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-2/runs/run-2/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["has_more"] is True
|
||||
assert len(body["data"]) == 50 # trimmed to limit
|
||||
|
||||
|
||||
def test_after_seq_forwarded_to_event_store():
|
||||
"""after_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(10)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
)
|
||||
|
||||
|
||||
def test_before_seq_forwarded_to_event_store():
|
||||
"""before_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(3)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4", "run-4",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_custom_limit_forwarded_to_event_store():
|
||||
"""Custom limit is forwarded as limit+1 to the event store."""
|
||||
rows = [_make_message(i) for i in range(1, 6)]
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(event_store=event_store)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5", "run-5",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_data_when_no_messages():
|
||||
"""Returns empty data list with has_more=False when no messages exist."""
|
||||
app = _make_app(event_store=_make_event_store([]))
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-6/runs/run-6/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["data"] == []
|
||||
assert body["has_more"] is False
|
||||
@@ -0,0 +1,439 @@
|
||||
"""Tests for event-store-backed message loading in thread state/history endpoints.
|
||||
|
||||
Covers the helper functions added to ``app/gateway/routers/threads.py``:
|
||||
|
||||
- ``_sanitize_legacy_command_repr`` — extracts inner ToolMessage text from
|
||||
legacy ``str(Command(...))`` strings captured before the ``journal.py``
|
||||
fix for state-updating tools like ``present_files``.
|
||||
- ``_get_event_store_messages`` — loads the full message stream with full
|
||||
pagination, copy-on-read id patching, legacy Command sanitization, and
|
||||
a clean fallback to ``None`` when the event store is unavailable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.routers.threads import (
|
||||
_get_event_store_messages,
|
||||
_sanitize_legacy_command_repr,
|
||||
)
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def event_store() -> MemoryRunEventStore:
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
class _FakeFeedbackRepo:
|
||||
"""Minimal ``FeedbackRepository`` stand-in that returns a configured map."""
|
||||
|
||||
def __init__(self, by_run: dict[str, dict] | None = None) -> None:
|
||||
self._by_run = by_run or {}
|
||||
|
||||
async def list_by_thread_grouped(self, thread_id: str, *, user_id: str | None) -> dict[str, dict]:
|
||||
return dict(self._by_run)
|
||||
|
||||
|
||||
def _make_request(
|
||||
event_store: MemoryRunEventStore,
|
||||
feedback_repo: _FakeFeedbackRepo | None = None,
|
||||
) -> Any:
|
||||
"""Build a minimal FastAPI-like Request object.
|
||||
|
||||
``get_run_event_store(request)`` reads ``request.app.state.run_event_store``.
|
||||
``get_feedback_repo(request)`` reads ``request.app.state.feedback_repo``.
|
||||
``get_current_user`` is monkey-patched separately in tests that need it.
|
||||
"""
|
||||
state = SimpleNamespace(
|
||||
run_event_store=event_store,
|
||||
feedback_repo=feedback_repo or _FakeFeedbackRepo(),
|
||||
)
|
||||
app = SimpleNamespace(state=state)
|
||||
return SimpleNamespace(app=app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_current_user(monkeypatch):
|
||||
"""Stub out ``get_current_user`` so tests don't need real auth context."""
|
||||
import app.gateway.routers.threads as threads_mod
|
||||
|
||||
async def _fake(_request):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(threads_mod, "get_current_user", _fake)
|
||||
|
||||
|
||||
async def _seed_simple_run(store: MemoryRunEventStore, thread_id: str, run_id: str) -> None:
|
||||
"""Seed one run: human + ai_tool_call + tool_result + final ai_message, plus a trace."""
|
||||
await store.put(
|
||||
thread_id=thread_id, run_id=run_id,
|
||||
event_type="human_message", category="message",
|
||||
content={
|
||||
"type": "human", "id": None,
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
},
|
||||
)
|
||||
await store.put(
|
||||
thread_id=thread_id, run_id=run_id,
|
||||
event_type="ai_tool_call", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--tc1",
|
||||
"content": "",
|
||||
"tool_calls": [{"name": "search", "args": {"q": "x"}, "id": "call_1", "type": "tool_call"}],
|
||||
"invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
},
|
||||
)
|
||||
await store.put(
|
||||
thread_id=thread_id, run_id=run_id,
|
||||
event_type="tool_result", category="message",
|
||||
content={
|
||||
"type": "tool", "id": None,
|
||||
"content": "results",
|
||||
"tool_call_id": "call_1", "name": "search",
|
||||
"artifact": None, "status": "success",
|
||||
"additional_kwargs": {}, "response_metadata": {},
|
||||
},
|
||||
)
|
||||
await store.put(
|
||||
thread_id=thread_id, run_id=run_id,
|
||||
event_type="ai_message", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--final1",
|
||||
"content": "done",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {"finish_reason": "stop"}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
|
||||
},
|
||||
)
|
||||
# Non-message trace — must be filtered out.
|
||||
await store.put(
|
||||
thread_id=thread_id, run_id=run_id,
|
||||
event_type="llm_request", category="trace",
|
||||
content={"model": "test"},
|
||||
)
|
||||
|
||||
|
||||
class TestSanitizeLegacyCommandRepr:
|
||||
def test_passthrough_non_string(self):
|
||||
assert _sanitize_legacy_command_repr(None) is None
|
||||
assert _sanitize_legacy_command_repr(42) == 42
|
||||
assert _sanitize_legacy_command_repr([{"type": "text", "text": "x"}]) == [{"type": "text", "text": "x"}]
|
||||
|
||||
def test_passthrough_plain_string(self):
|
||||
assert _sanitize_legacy_command_repr("Successfully presented files") == "Successfully presented files"
|
||||
assert _sanitize_legacy_command_repr("") == ""
|
||||
|
||||
def test_extracts_inner_content_single_quotes(self):
|
||||
legacy = (
|
||||
"Command(update={'artifacts': ['/mnt/user-data/outputs/report.md'], "
|
||||
"'messages': [ToolMessage(content='Successfully presented files', "
|
||||
"tool_call_id='call_abc')]})"
|
||||
)
|
||||
assert _sanitize_legacy_command_repr(legacy) == "Successfully presented files"
|
||||
|
||||
def test_extracts_inner_content_double_quotes(self):
|
||||
legacy = 'Command(update={"messages": [ToolMessage(content="ok", tool_call_id="x")]})'
|
||||
assert _sanitize_legacy_command_repr(legacy) == "ok"
|
||||
|
||||
def test_unparseable_command_returns_original(self):
|
||||
legacy = "Command(update={'something_else': 1})"
|
||||
assert _sanitize_legacy_command_repr(legacy) == legacy
|
||||
|
||||
|
||||
class TestGetEventStoreMessages:
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_none_when_store_empty(self, event_store):
|
||||
request = _make_request(event_store)
|
||||
assert await _get_event_store_messages(request, "t_missing") is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_extracts_all_message_types_in_order(self, event_store):
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
request = _make_request(event_store)
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
types = [m["type"] for m in messages]
|
||||
assert types == ["human", "ai", "tool", "ai"]
|
||||
# Trace events must not appear
|
||||
for m in messages:
|
||||
assert m.get("type") in {"human", "ai", "tool"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_null_ids_get_deterministic_uuid5(self, event_store):
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
request = _make_request(event_store)
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
|
||||
# AI messages keep their LLM ids
|
||||
assert messages[1]["id"] == "lc_run--tc1"
|
||||
assert messages[3]["id"] == "lc_run--final1"
|
||||
|
||||
# Human (seq=1) + tool (seq=3) get deterministic uuid5
|
||||
expected_human_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:1"))
|
||||
expected_tool_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:3"))
|
||||
assert messages[0]["id"] == expected_human_id
|
||||
assert messages[2]["id"] == expected_tool_id
|
||||
|
||||
# Re-running produces the same ids (stability across requests)
|
||||
messages2 = await _get_event_store_messages(request, "t1")
|
||||
assert [m["id"] for m in messages2] == [m["id"] for m in messages]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_helper_does_not_mutate_store(self, event_store):
|
||||
"""Helper must copy content dicts; the live store must stay unchanged."""
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
request = _make_request(event_store)
|
||||
_ = await _get_event_store_messages(request, "t1")
|
||||
|
||||
# Raw store records still have id=None for human/tool
|
||||
raw = await event_store.list_messages("t1", limit=500)
|
||||
human = next(e for e in raw if e["content"]["type"] == "human")
|
||||
tool = next(e for e in raw if e["content"]["type"] == "tool")
|
||||
assert human["content"]["id"] is None
|
||||
assert tool["content"]["id"] is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_legacy_command_repr_sanitized(self, event_store):
|
||||
"""A tool_result whose content is a legacy ``str(Command(...))`` is cleaned."""
|
||||
legacy = (
|
||||
"Command(update={'artifacts': ['/mnt/user-data/outputs/x.md'], "
|
||||
"'messages': [ToolMessage(content='Successfully presented files', "
|
||||
"tool_call_id='call_p')]})"
|
||||
)
|
||||
await event_store.put(
|
||||
thread_id="t2", run_id="r1",
|
||||
event_type="tool_result", category="message",
|
||||
content={
|
||||
"type": "tool", "id": None,
|
||||
"content": legacy,
|
||||
"tool_call_id": "call_p", "name": "present_files",
|
||||
"artifact": None, "status": "success",
|
||||
"additional_kwargs": {}, "response_metadata": {},
|
||||
},
|
||||
)
|
||||
request = _make_request(event_store)
|
||||
messages = await _get_event_store_messages(request, "t2")
|
||||
assert messages is not None and len(messages) == 1
|
||||
assert messages[0]["content"] == "Successfully presented files"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination_covers_more_than_one_page(self, event_store, monkeypatch):
|
||||
"""Simulate a long thread that exceeds a single page to exercise the loop."""
|
||||
thread_id = "t_long"
|
||||
# Seed 12 human messages
|
||||
for i in range(12):
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
content={
|
||||
"type": "human", "id": None,
|
||||
"content": [{"type": "text", "text": f"msg {i}"}],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Force small page size to exercise pagination
|
||||
import app.gateway.routers.threads as threads_mod
|
||||
original = threads_mod._get_event_store_messages
|
||||
|
||||
# Monkeypatch MemoryRunEventStore.list_messages to assert it's called with cursor pagination
|
||||
calls: list[dict] = []
|
||||
real_list = event_store.list_messages
|
||||
|
||||
async def spy_list_messages(tid, *, limit=50, before_seq=None, after_seq=None):
|
||||
calls.append({"limit": limit, "after_seq": after_seq})
|
||||
return await real_list(tid, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
|
||||
monkeypatch.setattr(event_store, "list_messages", spy_list_messages)
|
||||
|
||||
request = _make_request(event_store)
|
||||
messages = await original(request, thread_id)
|
||||
assert messages is not None
|
||||
assert len(messages) == 12
|
||||
assert [m["content"][0]["text"] for m in messages] == [f"msg {i}" for i in range(12)]
|
||||
# At least one call was made with after_seq=None (the initial page)
|
||||
assert any(c["after_seq"] is None for c in calls)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_summarize_regression_recovers_pre_summarize_messages(self, event_store):
|
||||
"""The exact bug: checkpoint would have only post-summarize messages;
|
||||
event store must surface the original pre-summarize human query."""
|
||||
# Run 1 (pre-summarize)
|
||||
await event_store.put(
|
||||
thread_id="t_sum", run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
content={
|
||||
"type": "human", "id": None,
|
||||
"content": [{"type": "text", "text": "original question"}],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
},
|
||||
)
|
||||
await event_store.put(
|
||||
thread_id="t_sum", run_id="r1",
|
||||
event_type="ai_message", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--r1",
|
||||
"content": "first answer",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
},
|
||||
)
|
||||
# Run 2 (post-summarize — what the checkpoint still has)
|
||||
await event_store.put(
|
||||
thread_id="t_sum", run_id="r2",
|
||||
event_type="human_message", category="message",
|
||||
content={
|
||||
"type": "human", "id": None,
|
||||
"content": [{"type": "text", "text": "follow up"}],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
},
|
||||
)
|
||||
await event_store.put(
|
||||
thread_id="t_sum", run_id="r2",
|
||||
event_type="ai_message", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--r2",
|
||||
"content": "second answer",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
|
||||
},
|
||||
)
|
||||
|
||||
request = _make_request(event_store)
|
||||
messages = await _get_event_store_messages(request, "t_sum")
|
||||
assert messages is not None
|
||||
# 4 messages, not 2 (which is what the summarized checkpoint would yield)
|
||||
assert len(messages) == 4
|
||||
assert messages[0]["content"][0]["text"] == "original question"
|
||||
assert messages[1]["id"] == "lc_run--r1"
|
||||
assert messages[3]["id"] == "lc_run--r2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_id_attached_to_every_message(self, event_store):
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
request = _make_request(event_store)
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
assert all(m.get("run_id") == "r1" for m in messages)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_feedback_attached_only_to_final_ai_message_per_run(self, event_store):
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
feedback_repo = _FakeFeedbackRepo(
|
||||
{"r1": {"feedback_id": "fb1", "rating": 1, "comment": "great"}}
|
||||
)
|
||||
request = _make_request(event_store, feedback_repo=feedback_repo)
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
|
||||
# human (0), ai_tool_call (1), tool (2), ai_message (3)
|
||||
final_ai = messages[3]
|
||||
assert final_ai["feedback"] == {
|
||||
"feedback_id": "fb1",
|
||||
"rating": 1,
|
||||
"comment": "great",
|
||||
}
|
||||
# Non-final messages must NOT have a feedback key at all — the
|
||||
# frontend keys button visibility off of this.
|
||||
assert "feedback" not in messages[0]
|
||||
assert "feedback" not in messages[1]
|
||||
assert "feedback" not in messages[2]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_feedback_none_when_no_row_for_run(self, event_store):
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
request = _make_request(event_store, feedback_repo=_FakeFeedbackRepo({}))
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
# Final ai_message gets an explicit ``None`` — distinguishes "eligible
|
||||
# but unrated" from "not eligible" (field absent).
|
||||
assert messages[3]["feedback"] is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_feedback_per_run_for_multi_run_thread(self, event_store):
|
||||
"""A thread with two runs: each final ai_message should get its own feedback."""
|
||||
# Run 1
|
||||
await event_store.put(
|
||||
thread_id="t_multi", run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
content={"type": "human", "id": None, "content": "q1",
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None},
|
||||
)
|
||||
await event_store.put(
|
||||
thread_id="t_multi", run_id="r1",
|
||||
event_type="ai_message", category="message",
|
||||
content={"type": "ai", "id": "lc_run--a1", "content": "a1",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": None},
|
||||
)
|
||||
# Run 2
|
||||
await event_store.put(
|
||||
thread_id="t_multi", run_id="r2",
|
||||
event_type="human_message", category="message",
|
||||
content={"type": "human", "id": None, "content": "q2",
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None},
|
||||
)
|
||||
await event_store.put(
|
||||
thread_id="t_multi", run_id="r2",
|
||||
event_type="ai_message", category="message",
|
||||
content={"type": "ai", "id": "lc_run--a2", "content": "a2",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": None},
|
||||
)
|
||||
feedback_repo = _FakeFeedbackRepo({
|
||||
"r1": {"feedback_id": "fb_r1", "rating": 1, "comment": None},
|
||||
"r2": {"feedback_id": "fb_r2", "rating": -1, "comment": "meh"},
|
||||
})
|
||||
request = _make_request(event_store, feedback_repo=feedback_repo)
|
||||
messages = await _get_event_store_messages(request, "t_multi")
|
||||
assert messages is not None
|
||||
# human[r1], ai[r1], human[r2], ai[r2]
|
||||
assert messages[1]["feedback"]["feedback_id"] == "fb_r1"
|
||||
assert messages[1]["feedback"]["rating"] == 1
|
||||
assert messages[3]["feedback"]["feedback_id"] == "fb_r2"
|
||||
assert messages[3]["feedback"]["rating"] == -1
|
||||
# Humans don't get feedback
|
||||
assert "feedback" not in messages[0]
|
||||
assert "feedback" not in messages[2]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_feedback_repo_failure_does_not_break_helper(self, monkeypatch, event_store):
|
||||
"""If feedback lookup throws, messages still come back without feedback."""
|
||||
await _seed_simple_run(event_store, "t1", "r1")
|
||||
|
||||
class _BoomRepo:
|
||||
async def list_by_thread_grouped(self, *a, **kw):
|
||||
raise RuntimeError("db down")
|
||||
|
||||
request = _make_request(event_store, feedback_repo=_BoomRepo())
|
||||
messages = await _get_event_store_messages(request, "t1")
|
||||
assert messages is not None
|
||||
assert len(messages) == 4
|
||||
for m in messages:
|
||||
assert "feedback" not in m
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_none_when_dep_raises(self, monkeypatch, event_store):
|
||||
"""When ``get_run_event_store`` is not configured, helper returns None."""
|
||||
import app.gateway.routers.threads as threads_mod
|
||||
|
||||
def boom(_request):
|
||||
raise RuntimeError("no store")
|
||||
|
||||
monkeypatch.setattr(threads_mod, "get_run_event_store", boom)
|
||||
request = _make_request(event_store)
|
||||
assert await threads_mod._get_event_store_messages(request, "t1") is None
|
||||
@@ -50,10 +50,13 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path):
|
||||
|
||||
|
||||
def test_delete_thread_route_cleans_thread_directory(tmp_path):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
paths = Paths(tmp_path)
|
||||
thread_dir = paths.thread_dir("thread-route")
|
||||
paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
user_id = get_effective_user_id()
|
||||
thread_dir = paths.thread_dir("thread-route", user_id=user_id)
|
||||
paths.sandbox_work_dir("thread-route", user_id=user_id).mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir("thread-route", user_id=user_id) / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
|
||||
app = make_authed_test_app()
|
||||
app.include_router(threads.router)
|
||||
@@ -113,14 +116,8 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path):
|
||||
# ── Server-reserved metadata key stripping ──────────────────────────────────
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_removes_owner_id():
|
||||
"""Client-supplied owner_id is dropped to prevent reflection attacks."""
|
||||
out = threads._strip_reserved_metadata({"owner_id": "victim-id", "title": "ok"})
|
||||
assert out == {"title": "ok"}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_removes_user_id():
|
||||
"""user_id is also reserved (defense in depth for any future use)."""
|
||||
"""Client-supplied user_id is dropped to prevent reflection attacks."""
|
||||
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
|
||||
assert out == {"title": "ok"}
|
||||
|
||||
@@ -136,6 +133,6 @@ def test_strip_reserved_metadata_empty_input():
|
||||
assert threads._strip_reserved_metadata({}) == {}
|
||||
|
||||
|
||||
def test_strip_reserved_metadata_strips_both_simultaneously():
|
||||
out = threads._strip_reserved_metadata({"owner_id": "x", "user_id": "y", "keep": "me"})
|
||||
def test_strip_reserved_metadata_strips_all_reserved_keys():
|
||||
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
|
||||
assert out == {"keep": "me"}
|
||||
|
||||
@@ -34,7 +34,9 @@ def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock:
|
||||
|
||||
|
||||
def _uploads_dir(tmp_path: Path, thread_id: str = THREAD_ID) -> Path:
|
||||
d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id)
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
CurrentUser,
|
||||
DEFAULT_USER_ID,
|
||||
get_current_user,
|
||||
get_effective_user_id,
|
||||
require_current_user,
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
@@ -67,3 +69,42 @@ def test_protocol_rejects_no_id():
|
||||
"""Objects without .id do not satisfy CurrentUser Protocol."""
|
||||
not_a_user = SimpleNamespace(email="no-id@example.com")
|
||||
assert not isinstance(not_a_user, CurrentUser)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_user_id / DEFAULT_USER_ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_user_id_is_default():
|
||||
assert DEFAULT_USER_ID == "default"
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_returns_default_when_no_user():
|
||||
"""No user in context -> fallback to DEFAULT_USER_ID."""
|
||||
assert get_effective_user_id() == "default"
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_returns_user_id_when_set():
|
||||
user = SimpleNamespace(id="u-abc-123")
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert get_effective_user_id() == "u-abc-123"
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_coerces_to_str():
|
||||
"""User.id might be a UUID object; must come back as str."""
|
||||
import uuid
|
||||
uid = uuid.uuid4()
|
||||
|
||||
user = SimpleNamespace(id=uid)
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert get_effective_user_id() == str(uid)
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
+2
-2
@@ -740,8 +740,8 @@ skill_evolution:
|
||||
# backend: sqlite -- Single-node deployment, files in sqlite_dir
|
||||
# backend: postgres -- Production multi-node deployment
|
||||
#
|
||||
# SQLite mode automatically uses separate .db files for checkpointer
|
||||
# and application data to avoid write-lock contention.
|
||||
# SQLite mode uses a single deerflow.db file with WAL journal mode
|
||||
# for both checkpointer and application data.
|
||||
#
|
||||
# Postgres mode: put your connection URL in .env as DATABASE_URL,
|
||||
# then reference it here with $DATABASE_URL.
|
||||
|
||||
@@ -0,0 +1,471 @@
|
||||
# Event Store History — Backend Compatibility Layer
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Replace checkpoint state with the append-only event store as the message source in the thread state/history endpoints, so summarization never causes message loss.
|
||||
|
||||
**Architecture:** The Gateway's `get_thread_state` and `get_thread_history` endpoints currently read messages from `checkpoint.channel_values["messages"]`. After summarization, those messages are replaced with a synthetic summary-as-human message and all pre-summarize messages are gone. We modify these endpoints to read messages from the RunEventStore instead (append-only, unaffected by summarization). The response shape for each message stays identical so the chat render path needs no changes, but the frontend's feedback hook must be aligned to use the same full-history view (see Task 4).
|
||||
|
||||
**Tech Stack:** Python (FastAPI, SQLAlchemy), pytest, TypeScript (React Query)
|
||||
|
||||
**Scope:** Gateway mode only (`make dev-pro`). Standard mode uses the LangGraph Server directly and does not go through these endpoints; the summarize bug is still present there and must be tracked as a separate follow-up (see §"Follow-ups" at end of plan).
|
||||
|
||||
**Prerequisite already landed:** `backend/packages/harness/deerflow/runtime/journal.py` now unwraps `Command(update={'messages':[ToolMessage(...)]})` in `on_tool_end`, so new runs that use state-updating tools (e.g. `present_files`) write the inner `ToolMessage` content to the event store instead of `str(Command(...))`. Legacy data captured before this fix is cleaned up defensively by the new helper (see Task 1 Step 3 `_sanitize_legacy_command_repr`).
|
||||
|
||||
---
|
||||
|
||||
## Real Data Alignment Analysis
|
||||
|
||||
Compared real `POST /history` response (checkpoint-based) with `run_events` table for thread `6d30913e-dcd4-41c8-8941-f66c716cf359` (docs/resp.json + backend/.deer-flow/data/deerflow.db). See `docs/superpowers/specs/2026-04-11-runjournal-history-evaluation.md` for full evidence chain.
|
||||
|
||||
| Message type | Fields compared | Difference |
|
||||
|-------------|----------------|------------|
|
||||
| human_message | all fields | `id` is `None` in event store, has UUID in checkpoint |
|
||||
| ai_message (tool_call) | all fields, 6 overlapping | **IDENTICAL** (0 diffs) |
|
||||
| ai_message (final) | all fields | **IDENTICAL** |
|
||||
| tool_result (normal) | all fields | Only `id` differs (`None` vs UUID) |
|
||||
| tool_result (from `Command`-returning tool) | content | **Legacy data stored `str(Command(...))` repr instead of inner ToolMessage** — fixed in journal.py for new runs; legacy rows sanitized by helper |
|
||||
|
||||
**Root cause for id difference:** LangGraph's checkpoint assigns `id` to HumanMessage and ToolMessage during graph execution. Event store writes happen earlier, when those ids are still None. AI messages receive `id` from the LLM response (`lc_run--*`) and are unaffected.
|
||||
|
||||
**Fix for id:** Generate deterministic UUIDs for `id=None` messages using `uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")` at read time. Patch a **copy** of the content dict, never the live store object.
|
||||
|
||||
**Summarize impact quantified on the reproducer thread**: event_store has 16 messages (7 AI + 9 others); checkpoint has 12 after summarize (5 AI + 7 others). AI id overlap: 5 of 7 — the 2 missing AI messages are pre-summarize.
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
| File | Action | Responsibility |
|
||||
|------|--------|----------------|
|
||||
| `backend/app/gateway/routers/threads.py` | Modify | Replace checkpoint messages with event store messages in `get_thread_state` and `get_thread_history` |
|
||||
| `backend/tests/test_thread_state_event_store.py` | Create | Tests for the modified endpoints |
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add `_get_event_store_messages` helper to `threads.py`
|
||||
|
||||
A shared helper that loads the **full** message stream from the event store, patches `id=None` messages with deterministic UUIDs, and defensively sanitizes legacy `Command(update=...)` reprs captured before the journal.py fix. Patches a copy of each content dict so the live store is never mutated.
|
||||
|
||||
**Design constraints (derived from evaluation §3, §4, §5):**
|
||||
- **Full pagination**, not `limit=1000`. `RunEventStore.list_messages` returns "latest N records" — a fixed limit silently truncates older messages. Use `count_messages()` to size the request or loop with `after_seq` cursors.
|
||||
- **Copy before mutate**. `MemoryRunEventStore` returns live dict references; the JSONL/DB stores may return detached rows but we must not rely on that. Always `content = dict(evt["content"])` before patching `id`.
|
||||
- **Legacy Command sanitization.** Legacy data contains `content["content"] == "Command(update={'artifacts': [...], 'messages': [ToolMessage(content='X', ...)]})"`. Regex-extract the inner ToolMessage content string and replace; if extraction fails, leave content as-is (still strictly better than nothing because checkpoint fallback is also wrong for summarized threads).
|
||||
- **User context.** `DbRunEventStore.list_messages` is user-scoped via `resolve_user_id(AUTO)` and relies on the auth contextvar set by `@require_permission`. Both endpoints are already decorated — document this dependency in the helper docstring.
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/app/gateway/routers/threads.py`
|
||||
- Test: `backend/tests/test_thread_state_event_store.py`
|
||||
|
||||
- [ ] **Step 1: Write the test**
|
||||
|
||||
Create `backend/tests/test_thread_state_event_store.py`:
|
||||
|
||||
```python
|
||||
"""Tests for event-store-backed message loading in thread state/history endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def event_store():
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
async def _seed_conversation(event_store: MemoryRunEventStore, thread_id: str = "t1"):
|
||||
"""Seed a realistic multi-turn conversation matching real checkpoint format."""
|
||||
# human_message: id is None (same as real data)
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
content={
|
||||
"type": "human", "id": None,
|
||||
"content": [{"type": "text", "text": "Hello"}],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
},
|
||||
)
|
||||
# ai_tool_call: id is set by LLM
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="ai_tool_call", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--abc123",
|
||||
"content": "",
|
||||
"tool_calls": [{"name": "search", "args": {"q": "cats"}, "id": "call_1", "type": "tool_call"}],
|
||||
"invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
|
||||
},
|
||||
)
|
||||
# tool_result: id is None (same as real data)
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="tool_result", category="message",
|
||||
content={
|
||||
"type": "tool", "id": None,
|
||||
"content": "Found 10 results",
|
||||
"tool_call_id": "call_1", "name": "search",
|
||||
"artifact": None, "status": "success",
|
||||
"additional_kwargs": {}, "response_metadata": {},
|
||||
},
|
||||
)
|
||||
# ai_message: id is set by LLM
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="ai_message", category="message",
|
||||
content={
|
||||
"type": "ai", "id": "lc_run--def456",
|
||||
"content": "I found 10 results about cats.",
|
||||
"tool_calls": [], "invalid_tool_calls": [],
|
||||
"additional_kwargs": {}, "response_metadata": {"finish_reason": "stop"}, "name": None,
|
||||
"usage_metadata": {"input_tokens": 200, "output_tokens": 100, "total_tokens": 300},
|
||||
},
|
||||
)
|
||||
# Also add a trace event — should NOT appear
|
||||
await event_store.put(
|
||||
thread_id=thread_id, run_id="r1",
|
||||
event_type="llm_request", category="trace",
|
||||
content={"model": "gpt-4"},
|
||||
)
|
||||
|
||||
|
||||
class TestGetEventStoreMessages:
|
||||
"""Verify event store message extraction with id patching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_all_message_types(self, event_store):
|
||||
await _seed_conversation(event_store)
|
||||
events = await event_store.list_messages("t1", limit=500)
|
||||
messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict) and "type" in evt["content"]]
|
||||
assert len(messages) == 4
|
||||
assert [m["type"] for m in messages] == ["human", "ai", "tool", "ai"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_ids_get_patched(self, event_store):
|
||||
"""Messages with id=None should get deterministic UUIDs."""
|
||||
await _seed_conversation(event_store)
|
||||
events = await event_store.list_messages("t1", limit=500)
|
||||
messages = []
|
||||
for evt in events:
|
||||
content = evt.get("content")
|
||||
if isinstance(content, dict) and "type" in content:
|
||||
if content.get("id") is None:
|
||||
content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"t1:{evt['seq']}"))
|
||||
messages.append(content)
|
||||
|
||||
# All messages now have an id
|
||||
for m in messages:
|
||||
assert m["id"] is not None
|
||||
assert isinstance(m["id"], str)
|
||||
assert len(m["id"]) > 0
|
||||
|
||||
# AI messages keep their original id
|
||||
assert messages[1]["id"] == "lc_run--abc123"
|
||||
assert messages[3]["id"] == "lc_run--def456"
|
||||
|
||||
# Human and tool messages get deterministic ids (same input = same output)
|
||||
human_id_1 = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:1"))
|
||||
assert messages[0]["id"] == human_id_1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_thread(self, event_store):
|
||||
events = await event_store.list_messages("nonexistent", limit=500)
|
||||
messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict)]
|
||||
assert messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_fields_preserved(self, event_store):
|
||||
await _seed_conversation(event_store)
|
||||
events = await event_store.list_messages("t1", limit=500)
|
||||
messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict) and "type" in evt["content"]]
|
||||
|
||||
# AI tool_call message
|
||||
ai_tc = messages[1]
|
||||
assert ai_tc["tool_calls"][0]["name"] == "search"
|
||||
assert ai_tc["tool_calls"][0]["id"] == "call_1"
|
||||
|
||||
# Tool result
|
||||
tool = messages[2]
|
||||
assert tool["tool_call_id"] == "call_1"
|
||||
assert tool["status"] == "success"
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run tests to verify they pass**
|
||||
|
||||
Run: `cd backend && PYTHONPATH=. uv run pytest tests/test_thread_state_event_store.py -v`
|
||||
|
||||
- [ ] **Step 3: Add the helper function and modify `get_thread_history`**
|
||||
|
||||
In `backend/app/gateway/routers/threads.py`:
|
||||
|
||||
1. Add import at the top:
|
||||
```python
|
||||
import uuid # ADD (may already exist, check first)
|
||||
from app.gateway.deps import get_run_event_store # ADD
|
||||
```
|
||||
|
||||
2. Add the helper function (before the endpoint functions, after the model definitions):
|
||||
|
||||
```python
|
||||
_LEGACY_CMD_INNER_CONTENT_RE = re.compile(
|
||||
r"ToolMessage\(content=(?P<q>['\"])(?P<inner>.*?)(?P=q)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_legacy_command_repr(content_field: Any) -> Any:
|
||||
"""Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr.
|
||||
|
||||
Runs that pre-date the ``on_tool_end`` fix in ``journal.py`` stored
|
||||
``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the
|
||||
tool_result content. New runs store ``'X'`` directly. For old threads, try
|
||||
to extract ``'X'`` defensively; return the original string if extraction
|
||||
fails (still no worse than the current checkpoint-based fallback, which is
|
||||
broken for summarized threads anyway).
|
||||
"""
|
||||
if not isinstance(content_field, str) or not content_field.startswith("Command(update="):
|
||||
return content_field
|
||||
match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field)
|
||||
return match.group("inner") if match else content_field
|
||||
|
||||
|
||||
async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None:
|
||||
"""Load messages from the event store, returning None if unavailable.
|
||||
|
||||
The event store is append-only and immune to summarization. Each
|
||||
message event's ``content`` field contains a ``model_dump()``'d
|
||||
LangChain Message dict that is already JSON-serialisable.
|
||||
|
||||
**Full pagination, not a fixed limit.** ``RunEventStore.list_messages``
|
||||
returns the newest ``limit`` records when no cursor is given, which
|
||||
silently drops older messages. We call ``count_messages()`` first and
|
||||
request that many records. For stores that may return fewer (e.g. filtered
|
||||
by user), we also fall back to ``after_seq``-cursor pagination.
|
||||
|
||||
**Copy-on-read.** Each content dict is copied before ``id`` is patched so
|
||||
the live store object is never mutated; ``MemoryRunEventStore`` returns
|
||||
live references.
|
||||
|
||||
**Legacy Command repr sanitization.** See ``_sanitize_legacy_command_repr``.
|
||||
|
||||
**User context.** ``DbRunEventStore`` is user-scoped by default via
|
||||
``resolve_user_id(AUTO)`` (see ``runtime/user_context.py``). Callers of
|
||||
this helper must be inside a request where ``@require_permission`` has
|
||||
populated the user contextvar. Both ``get_thread_history`` and
|
||||
``get_thread_state`` satisfy that. Do not call this helper from CLI or
|
||||
migration scripts without passing ``user_id=None`` explicitly.
|
||||
|
||||
Returns ``None`` when the event store is not configured or contains no
|
||||
messages for this thread, so callers can fall back to checkpoint messages.
|
||||
"""
|
||||
try:
|
||||
event_store = get_run_event_store(request)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
total = await event_store.count_messages(thread_id)
|
||||
except Exception:
|
||||
logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||
return None
|
||||
if not total:
|
||||
return None
|
||||
|
||||
# Batch by page_size to keep memory bounded for very long threads.
|
||||
page_size = 500
|
||||
collected: list[dict] = []
|
||||
after_seq: int | None = None
|
||||
while True:
|
||||
page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
|
||||
if not page:
|
||||
break
|
||||
collected.extend(page)
|
||||
if len(page) < page_size:
|
||||
break
|
||||
after_seq = page[-1].get("seq")
|
||||
if after_seq is None:
|
||||
break
|
||||
|
||||
messages: list[dict] = []
|
||||
for evt in collected:
|
||||
raw = evt.get("content")
|
||||
if not isinstance(raw, dict) or "type" not in raw:
|
||||
continue
|
||||
# Copy to avoid mutating the store-owned dict.
|
||||
content = dict(raw)
|
||||
if content.get("id") is None:
|
||||
content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}"))
|
||||
# Sanitize legacy Command reprs on tool_result messages only.
|
||||
if content.get("type") == "tool":
|
||||
content["content"] = _sanitize_legacy_command_repr(content.get("content"))
|
||||
messages.append(content)
|
||||
return messages if messages else None
|
||||
```
|
||||
|
||||
Also add `import re` at the top of the file if it isn't already imported.
|
||||
|
||||
3. In `get_thread_history` (around line 585-590), replace the messages section:
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
# Attach messages from checkpointer only for the latest checkpoint
|
||||
if is_latest_checkpoint:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
# Attach messages: prefer event store (immune to summarization),
|
||||
# fall back to checkpoint messages when event store is unavailable.
|
||||
if is_latest_checkpoint:
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
else:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Modify `get_thread_state` similarly**
|
||||
|
||||
In `get_thread_state` (around line 443-444), replace:
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
values = serialize_channel_values(channel_values)
|
||||
|
||||
# Override messages with event store data (immune to summarization)
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=values,
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Run all backend tests**
|
||||
|
||||
Run: `cd backend && PYTHONPATH=. uv run pytest tests/ -v --timeout=30 -x`
|
||||
|
||||
- [ ] **Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add backend/app/gateway/routers/threads.py backend/tests/test_thread_state_event_store.py
|
||||
git commit -m "feat(threads): load messages from event store instead of checkpoint state
|
||||
|
||||
Event store is append-only and immune to summarization. Messages with
|
||||
null ids (human, tool) get deterministic UUIDs based on thread_id:seq
|
||||
for stable frontend rendering."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2 (OPTIONAL, deferred): Reduce flush_threshold for shorter mid-stream gap
|
||||
|
||||
**Status:** Not a correctness fix. Re-evaluation (see spec) found that `RunJournal` already flushes on `run_end`, `run_error`, cancel, and worker `finally` paths. The only window this tuning narrows is a hard process crash or mid-run reload. Defer and decide separately; do not couple with Task 1 merge.
|
||||
|
||||
If pursued: change `flush_threshold` default from 20 → 5 in `journal.py:42`, rerun `tests/test_run_journal.py`, commit as a separate `perf(journal): …` commit.
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Fix `useThreadFeedback` pagination in frontend
|
||||
|
||||
Once `/history` returns the full event-store-backed message stream, the frontend's `runIdByAiIndex` map must also cover the full stream or its positional AI-index mapping drifts and feedback clicks go to the wrong `run_id`. The current hook hardcodes `limit=200`.
|
||||
|
||||
**Files:**
|
||||
- Modify: `frontend/src/core/threads/hooks.ts` (around line 679)
|
||||
|
||||
- [ ] **Step 1: Replace the fixed `?limit=200` with full pagination**
|
||||
|
||||
Change:
|
||||
|
||||
```ts
|
||||
const res = await fetchWithAuth(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/messages?limit=200`,
|
||||
);
|
||||
```
|
||||
|
||||
to a loop that pages via `after_seq` (or an equivalent query param exposed by the `/messages` endpoint — check `backend/app/gateway/routers/thread_runs.py:285-323` for the actual parameter names before writing the TS code). Accumulate `messages` until a page returns fewer than the page size.
|
||||
|
||||
- [ ] **Step 2: Defensive index guard**
|
||||
|
||||
`runIdByAiIndex[aiMessageIndex]` can still be `undefined` when the frontend renders optimistic state before the messages query refreshes. The current `?? undefined` in `message-list.tsx:71` already handles this; do not remove it.
|
||||
|
||||
- [ ] **Step 3: Invalidate `["thread-feedback", threadId]` after a new run**
|
||||
|
||||
In `useThreadStream` (or wherever stream-end is handled), call `queryClient.invalidateQueries({ queryKey: ["thread-feedback", threadId] })` when the stream closes so the runIdByAiIndex picks up the new run's AI message immediately.
|
||||
|
||||
- [ ] **Step 4: Run `pnpm check`**
|
||||
|
||||
```bash
|
||||
cd frontend && pnpm check
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add frontend/src/core/threads/hooks.ts
|
||||
git commit -m "fix(feedback): paginate useThreadFeedback and invalidate after stream"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: End-to-end test — summarize + multi-run feedback
|
||||
|
||||
Add a regression test that exercises the exact bug class we are fixing: a summarized thread with at least two runs, where feedback clicks must target the correct `run_id`.
|
||||
|
||||
**Files:**
|
||||
- Modify: `backend/tests/test_thread_state_event_store.py`
|
||||
|
||||
- [ ] **Step 1: Write the test**
|
||||
|
||||
Seed a `MemoryRunEventStore` with two runs worth of messages (`r1`: human + ai + human + ai, `r2`: human + ai), then simulate a summarized checkpoint state that drops the `r1` messages. Call `_get_event_store_messages` and assert:
|
||||
|
||||
- Length matches the event store, not the checkpoint
|
||||
- The first message is the original `r1` human, not a summary
|
||||
- AI messages preserve their `lc_run--*` ids in order
|
||||
- Any `id=None` messages get a stable `uuid5(...)` id
|
||||
- A legacy `str(Command(update=...))` content field in a tool_result is sanitized to the inner text
|
||||
|
||||
- [ ] **Step 2: Run the new test**
|
||||
|
||||
```bash
|
||||
cd backend && PYTHONPATH=. uv run pytest tests/test_thread_state_event_store.py -v
|
||||
```
|
||||
|
||||
- [ ] **Step 3: Commit with Tasks 1, 3 changes**
|
||||
|
||||
Bundle with the Task 1 commit so tests always land alongside the implementation.
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Standard mode follow-up (documentation only)
|
||||
|
||||
Standard mode (`make dev`) hits LangGraph Server directly for `/threads/{id}/history` and does not go through the Gateway router we just patched. The summarize bug is still present there.
|
||||
|
||||
**Files:**
|
||||
- Modify: this plan (add follow-up section at the bottom, see below) OR create a separate tracking issue
|
||||
|
||||
- [ ] **Step 1: Record the gap**
|
||||
|
||||
Append to the bottom of this plan (or open a GitHub issue and link it):
|
||||
|
||||
> **Follow-up — Standard mode summarize bug**
|
||||
> `get_thread_history` in `backend/app/gateway/routers/threads.py` is only hit in Gateway mode. Standard mode proxies `/api/langgraph/*` directly to the LangGraph Server (see `backend/CLAUDE.md` nginx routing and `frontend/CLAUDE.md` `NEXT_PUBLIC_LANGGRAPH_BASE_URL`). The summarize-message-loss symptom is still reproducible there. Options: (a) teach the LangGraph Server checkpointer to branch on an override, (b) move `/history` behind Gateway in Standard mode as well, (c) accept as known limitation for Standard mode. Decide before GA.
|
||||
@@ -0,0 +1,191 @@
|
||||
# RunJournal 替换 History Messages — 方案评估与对比
|
||||
|
||||
**日期**:2026-04-11
|
||||
**分支**:`rayhpeng/fix-persistence-new`
|
||||
**相关 plan**:[`docs/superpowers/plans/2026-04-10-event-store-history.md`](../plans/2026-04-10-event-store-history.md)(尚未落地)
|
||||
|
||||
---
|
||||
|
||||
## 1. 问题与数据核对
|
||||
|
||||
**症状**:SummarizationMiddleware 触发后,前端历史中无法展示 summarize 之前的真实用户消息。
|
||||
|
||||
**复现数据**(thread `6d30913e-dcd4-41c8-8941-f66c716cf359`):
|
||||
|
||||
| 数据源 | seq=1 的 message | 总 message 数 | 是否保留原始 human |
|
||||
|---|---|---:|---|
|
||||
| `run_events`(SQLite) | human `"最新伊美局势"` | 9(1 human + 7 ai_tool_call + 9 tool_result + 1 ai_message) | ✅ |
|
||||
| `/history` 响应(`docs/resp.json`) | type=human,content=`"Here is a summary of the conversation to date:…"` | 不定 | ❌(已被 summary 替换)|
|
||||
|
||||
**根因**:`backend/app/gateway/routers/threads.py:587-589` 的 `get_thread_history` 从 `checkpoint.channel_values["messages"]` 读取,而 LangGraph 的 SummarizationMiddleware 会原地改写这个列表。
|
||||
|
||||
---
|
||||
|
||||
## 2. 候选方案
|
||||
|
||||
| 方案 | 描述 | 本次是否推荐 |
|
||||
|---|---|---|
|
||||
| **A. event_store 覆盖 messages**(已有 plan) | `/history`、`/state` 改读 `RunEventStore.list_messages()`,覆盖 `channel_values["messages"]`;其它字段保持 checkpoint 来源 | ✅ 主方案 |
|
||||
| B. 修 SummarizationMiddleware | 让 summarize 不原地替换 messages(作为附加 system message) | ❌ 违背 summarize 的 token 预算初衷 |
|
||||
| C. 双读合并(checkpoint + event_store diff) | 合并 summarize 切点前后的两段 | ❌ 合并逻辑复杂无额外收益 |
|
||||
| D. 切到现有 `/api/threads/{id}/messages` 端点 | 前端直接消费已经存在的 event-store 消息端点(`thread_runs.py:285-323`)| ⚠️ 更干净但需要前端改动 |
|
||||
|
||||
---
|
||||
|
||||
## 3. Claude 自评 vs Codex 独立评估
|
||||
|
||||
两方独立分析了同一份 plan。重合点基本一致,但 **Codex 发现了一个我遗漏的关键 bug**。
|
||||
|
||||
### 3.1 一致结论
|
||||
|
||||
| 维度 | 结论 |
|
||||
|---|---|
|
||||
| 正确性方向 | event_store 是 append-only + 不受 summarize 影响,方向正确 |
|
||||
| ID 补齐 | `uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")` 稳定且确定性,安全 |
|
||||
| 前端 schema | 零改动 |
|
||||
| Non-message 字段(artifacts/todos/title/thread_data) | summarize 只影响 messages,不需要覆盖其它字段 |
|
||||
| 多 checkpoint 语义 | 前端 `useStream` 只取 `limit: 1`(`frontend/src/core/threads/hooks.ts:203-210`),不做时间旅行;latest-only 可接受但应在注释/文档写清楚 |
|
||||
| 作用域 | 仅 Gateway mode;Standard mode 直连 LangGraph Server,bug 在默认部署路径仍然存在 |
|
||||
|
||||
### 3.2 Claude 的独立观察
|
||||
|
||||
1. 已验证数据对齐:plan 文档第 15-28 行的真实数据对齐表与本次 `run_events` 导出一致(9 条消息 id 分布:AI 来自 LLM `lc_run--*`、human/tool 为 None)。
|
||||
2. 担心 `run_end` / `run_error` / `cancel` 路径未必都 flush —— 这一点 Codex 实际核查了代码并给出确定结论(见下)。
|
||||
3. 方案 A 的单文件改动约 60 行,复杂度小。
|
||||
|
||||
### 3.3 Codex 的关键补充(Claude 遗漏)
|
||||
|
||||
> **Bug #1 — Plan 用 `limit=1000` 并非全量**
|
||||
> `RunEventStore.list_messages()` 的语义是"返回最新 limit 条"(`base.py:51-65`、`db.py:151-181`)。对于消息数超过 1000 的长对话,plan 当前写法会**丢掉最早的消息**,再次引入"消息丢失"bug(只是换了丢失的段)。
|
||||
|
||||
> **Bug #2 — helper 就地修改了 store 的 dict**
|
||||
> plan 的 helper 里对 `content` 原地写 `id`;`MemoryRunEventStore` 返回的是**活引用**,会污染 store 中的对象。应 deep-copy 或 dict 推导出新对象。
|
||||
|
||||
> **Flush 路径已核查**:
|
||||
> `RunJournal` 在 threshold (`journal.py:360-373`)、`run_end` (`91-96`)、`run_error` (`97-106`)、worker `finally` (`worker.py:280-286`) 都会 flush;`CancelledError` 也走 finally。**正常 end/error/cancel 都 flush,仅硬 kill / 进程崩溃会丢缓冲区**。
|
||||
> 因此 `flush_threshold 20 → 5` 的意义**仅在于硬崩溃窗口**与 mid-run reload 可见性,**不是正确性修复**,属于可选 tuning。代价是更多 put_batch / SQLite churn;且 `_flush_sync()` (`383-398`) 已防止并发 flush,所以"每 5 条一 flush"是 best-effort 非严格保证。
|
||||
|
||||
### 3.4 Codex 未否决但提示的次要点
|
||||
|
||||
- 方案 D(消费现有 `/api/threads/{id}/messages` 端点)更干净但需前端改动。
|
||||
- `/history` 一旦被方案 A 改过,就不再是严格意义上的"按 checkpoint 快照"API(对 messages 字段),应写进注释和 API 文档。
|
||||
- Standard mode 的 summarize bug 应建立独立 follow-up issue。
|
||||
|
||||
---
|
||||
|
||||
## 4. 最终合并判决
|
||||
|
||||
**Codex**:APPROVE-WITH-CHANGES
|
||||
**Claude**:同意 Codex 的判决
|
||||
|
||||
### 合并前必须修改(Top 3)
|
||||
|
||||
1. **修复分页 bug**:不能用固定 `limit=1000`。必须用以下之一:
|
||||
- `count = await event_store.count_messages(thread_id)`,再 `list_messages(thread_id, limit=count)`
|
||||
- 或循环 cursor 分页(`after_seq`)直到耗尽
|
||||
2. **不要原地修改 store dict**:helper 对 `content` 的 id 补齐需要 copy(`dict(content)` 浅拷贝足够,因为只写 top-level `id`)
|
||||
3. **Standard mode 显式 follow-up**:在 plan 文末加 "Standard-mode follow-up: TODO #xxx",或在合并 PR 描述中明确这是 Gateway-only 止血
|
||||
|
||||
### 可选(非阻塞)
|
||||
|
||||
4. `flush_threshold 20 → 5` 降级为"可选 tuning",不是修复的一部分;或独立一条 commit 并说明只对硬崩溃窗口有用
|
||||
5. `get_thread_history` 新增注释,说明 messages 字段脱离了 checkpoint 快照语义
|
||||
6. 测试覆盖:模拟 summarize 后的 checkpoint + 真实 event_store,端到端验证 `/history` 返回包含原始 human 消息
|
||||
|
||||
---
|
||||
|
||||
## 5. 推荐执行顺序
|
||||
|
||||
1. 按本文档 §4 修订 `docs/superpowers/plans/2026-04-10-event-store-history.md`(主要是 Task 1 的 helper 实现 + 分页)
|
||||
2. 按修订后的 plan 执行(走 `superpowers:executing-plans`)
|
||||
3. 合并后立即建 Standard mode follow-up issue
|
||||
|
||||
## 6. Feedback 影响分析(2026-04-11 补充)
|
||||
|
||||
### 6.1 数据模型
|
||||
|
||||
`feedback` 表(`persistence/feedback/model.py`):
|
||||
|
||||
| 字段 | 说明 |
|
||||
|---|---|
|
||||
| `feedback_id` PK | - |
|
||||
| `run_id` NOT NULL | 反馈目标 run |
|
||||
| `thread_id` NOT NULL | - |
|
||||
| `user_id` | - |
|
||||
| `message_id` nullable | 注释明确写:`optional RunEventStore event identifier` — 已经面向 event_store 设计 |
|
||||
| UNIQUE(thread_id, run_id, user_id) | 每 run 每用户至多一条 |
|
||||
|
||||
**结论**:feedback **不按 message uuid 存**,按 `run_id` 存,所以 summarize 导致的 checkpoint messages 丢失**不会影响 feedback 存储**。schema 天生与 event_store 兼容,**无需数据迁移**。
|
||||
|
||||
### 6.2 前端的 runId 映射:发现隐藏 bug
|
||||
|
||||
前端 feedback 目前走两条并行的数据链:
|
||||
|
||||
| 用途 | 数据源 | 位置 |
|
||||
|---|---|---|
|
||||
| 渲染消息体 | `POST /history`(checkpoint) | `useStream` → `thread.messages` |
|
||||
| 拿 `runId` 映射 | `GET /api/threads/{id}/messages?limit=200`(**event_store**) | `useThreadFeedback` (`hooks.ts:669-709`) |
|
||||
|
||||
两者通过 **"AI 消息的序号"** 对齐:
|
||||
|
||||
```ts
|
||||
// hooks.ts:691-698
|
||||
for (const msg of messages) {
|
||||
if (msg.event_type === "ai_message") {
|
||||
runIdByAiIndex.push(msg.run_id); // 只按 AI 顺序 push
|
||||
}
|
||||
}
|
||||
// message-list.tsx:70-71
|
||||
runId = feedbackData.runIdByAiIndex[aiMessageIndex]
|
||||
```
|
||||
|
||||
**Bug**:summarize 过的 thread 里,两条数据链的 AI 消息数量和顺序**不一致**:
|
||||
|
||||
| 数据源 | 本 thread 的 AI 消息序列 | 数量 |
|
||||
|---|---|---:|
|
||||
| `/history`(checkpoint,summarize 后) | seq=19,31,37,45,53 | 5 |
|
||||
| `/messages`(event_store,完整) | seq=5,13,19,31,37,45,53 | 7 |
|
||||
|
||||
结果:前端渲染的"第 0 条 AI 消息"是 seq=19,但 `runIdByAiIndex[0]` 指向 seq=5 的 run(本例同一 run 里没事,**跨多 run 的 thread 点赞就会打到错的 run 上**)。
|
||||
|
||||
**这个 bug 和本次 plan 无关,已经存在了**。只是用户未必注意到。
|
||||
|
||||
### 6.3 方案 A 对 feedback 的影响
|
||||
|
||||
**负面**:无。feedback 存储不受影响。
|
||||
|
||||
**正面(意外收益)**:`/history` 切换到 event_store 后,**两条数据链的 AI 消息序列自动对齐**,§6.2 的隐藏 bug 被顺带修好。
|
||||
|
||||
**前提条件**(加入 Top 3 改动之一同等重要):
|
||||
|
||||
- 新 helper 必须和 `/messages` 端点用**同样的消息获取逻辑**(same store, same filter)。否则两条链仍然可能在边界条件下漂移
|
||||
- 具体说:**两边都要做完整分页**。目前 `/messages?limit=200` 在前端硬编码 200,如果 thread 有 >200 条消息就会截断;plan 的 `limit=1000` 也一样有上限。两个上限不一致 → 两边顺序不再对齐 → feedback 映射错位
|
||||
- **必须修**:`useThreadFeedback` 的 `limit=200` 需要改成分页获取全部,或者 `/messages` 后端改为默认全量
|
||||
|
||||
### 6.4 对前端改造顺序的影响
|
||||
|
||||
原 plan 声明"零前端改动",但加入 feedback 考虑后应修正为:
|
||||
|
||||
| 改动 | 必须 | 可选 |
|
||||
|---|---|---|
|
||||
| 后端 `/history` 改读 event_store | ✅ | - |
|
||||
| 后端 helper 用分页而非 `limit=1000` | ✅ | - |
|
||||
| 前端 `useThreadFeedback` 改用分页或提升 limit | ✅ | - |
|
||||
| `runIdByAiIndex` 增加防御:索引越界 fallback `undefined`(已有)| - | ✅ 已经是 |
|
||||
| 前端改用 `/messages` 直接做渲染(方案 D) | - | ✅ 长期更干净 |
|
||||
|
||||
### 6.5 feedback 相关的新 Top 3 补充
|
||||
|
||||
在原来的 Top 3 之外,再加:
|
||||
|
||||
4. **前端 `useThreadFeedback` 必须分页或拉全**(`frontend/src/core/threads/hooks.ts:679`),否则和 `/history` 的新全量行为仍然错位
|
||||
5. **端到端测试**:一个 thread 跨 >1 个 run + 触发 summarize + 给历史 AI 消息点赞,确认 feedback 打到正确的 run_id
|
||||
6. **TanStack Query 缓存协调**:`thread-feedback` 与 history 查询的 `staleTime` / invalidation 需要在新 run 结束时同步刷新,否则新消息写入后 `runIdByAiIndex` 没更新,点赞会打到上一个 run
|
||||
|
||||
---
|
||||
|
||||
## 8. 未决问题
|
||||
|
||||
- `RunEventStore.count_messages()` 与 `list_messages(after_seq=...)` 的实际性能(SQLite 上对于数千消息级别应无问题,但未压测)
|
||||
- `MemoryRunEventStore` 与 `DbRunEventStore` 分页语义是否一致(Codex 只核查了 `db.py`,`memory.py` 需确认)
|
||||
- 是否应把 `/api/threads/{id}/messages` 提升为前端主用 endpoint,把 `/history` 保留为纯 checkpoint API —— 架构层面更干净但成本更高
|
||||
@@ -0,0 +1,203 @@
|
||||
# Summarize Marker in History — Design & Verification
|
||||
|
||||
**Date**: 2026-04-11
|
||||
**Branch**: `rayhpeng/fix-persistence-new`
|
||||
**Status**: Design approved, implementation deferred to a follow-up PR
|
||||
**Depends on**: [`2026-04-11-runjournal-history-evaluation.md`](./2026-04-11-runjournal-history-evaluation.md) (the event-store-backed history fix this builds on)
|
||||
|
||||
---
|
||||
|
||||
## 1. Goal
|
||||
|
||||
Display a "summarization happened here" marker in the conversation history UI when `SummarizationMiddleware` ran mid-run, so users understand why earlier messages look condensed or missing. The event-store-backed `/history` fix already recovered the original messages; this spec adds a **visible marker** at the seq position where summarization occurred, optionally showing the generated summary text.
|
||||
|
||||
## 2. Investigation findings
|
||||
|
||||
### 2.1 Today's state: zero middleware records
|
||||
|
||||
Full scan of `backend/.deer-flow/data/deerflow.db` `run_events`:
|
||||
|
||||
| category | rows |
|
||||
|---|---:|
|
||||
| trace | 76 |
|
||||
| message | 34 |
|
||||
| lifecycle | 8 |
|
||||
| **middleware** | **0** |
|
||||
|
||||
No row has `event_type` containing `summariz` or `middleware`. The middleware category is dead in production.
|
||||
|
||||
### 2.2 Why: two dead code paths in `journal.py`
|
||||
|
||||
| Location | Status |
|
||||
|---|---|
|
||||
| `journal.py:343-362` — `on_custom_event("summarization", ...)` writes one trace event + one `category="middleware"` event. | Dead. Only fires when something calls `adispatch_custom_event("summarization", {...})`. The upstream LangChain `SummarizationMiddleware` (`.venv/.../langchain/agents/middleware/summarization.py:272`) **never emits custom events** — its `before_model`/`abefore_model` just mutate messages in place and return `{'messages': new_messages}`. Callback never triggered. |
|
||||
| `journal.py:449` — `record_middleware(tag, *, name, hook, action, changes)` helper | Dead. Grep shows zero callers in the harness. Added speculatively, never wired up. |
|
||||
|
||||
### 2.3 Concrete evidence of summarize running unlogged
|
||||
|
||||
Thread `3d5dea4a-0983-4727-a4e8-41a64428933a`:
|
||||
|
||||
- `run_events` seq=1 → original human `"写一份关于deer-flow的详细技术报告"` ✓ (event store is fine)
|
||||
- `run_events` seq=43 → `llm_request` trace whose `messages[0]` literal contains `"Here is a summary of the conversation to date:"` — proof that SummarizationMiddleware did inject a summary mid-run
|
||||
- Zero rows with `category='middleware'` for this thread → nothing captured for UI to render
|
||||
|
||||
## 3. Approaches considered
|
||||
|
||||
### A. Subclass `SummarizationMiddleware` and dispatch a custom event
|
||||
|
||||
Wrap the upstream class, override `abefore_model`, call `await adispatch_custom_event("summarization", {...})` after super(). Journal's existing `on_custom_event` path captures it.
|
||||
|
||||
### B. Frontend-only diff heuristic
|
||||
|
||||
Compare `event_store.count_messages()` vs rendered count, infer summarization happened from the gap. **Rejected**: can't pinpoint position in the stream, can't show summary text. Only yields a vague badge.
|
||||
|
||||
### C. Hybrid A + frontend inline card rendered at the middleware event's seq position
|
||||
|
||||
Same backend as A, plus frontend renders an inline `[N messages condensed]` card at the correct chronological position. **Recommended terminal state**.
|
||||
|
||||
## 4. Subagent's wrong claim and its rebuttal
|
||||
|
||||
An independent agent flagged approach A as structurally broken because:
|
||||
|
||||
> `RunnableCallable(trace=False)` skips `set_config_context`, therefore `var_child_runnable_config` is never set, therefore `adispatch_custom_event` raises `RuntimeError("Unable to dispatch an adhoc event without a parent run id")`.
|
||||
|
||||
**This is wrong.** The user's counter-intuition was correct: `trace=False` does not prevent `adispatch_custom_event` from working, as long as the middleware signature explicitly accepts `config: RunnableConfig`. The mechanism:
|
||||
|
||||
1. `RunnableCallable.__init__` (`langgraph/_internal/_runnable.py:293-319`) inspects the function signature. If it accepts `config: RunnableConfig`, that parameter is recorded in `self.func_accepts`.
|
||||
2. Both `trace=True` and `trace=False` branches of `ainvoke` run the same kwarg-injection loop (`_runnable.py:349-356`): `if kw == "config": kw_value = config`. The `config` passed to `ainvoke` (from Pregel's `task.proc.ainvoke(task.input, config)` at `pregel/_retry.py:138`) is the task config with callbacks already bound.
|
||||
3. Inside the middleware, passing that `config` explicitly to `adispatch_custom_event(..., config=config)` means the function doesn't rely on `var_child_runnable_config.get()` at all. The LangChain docstring at `langchain_core/callbacks/manager.py:2574-2579` even says "If using python 3.10 and async, you MUST specify the config parameter" — which is exactly this path.
|
||||
|
||||
`trace=False` only changes whether **this runnable layer creates a new child callback scope**. It does not affect whether the outer-layer config (with callbacks including `RunJournal`) is passed down to the function.
|
||||
|
||||
## 5. Verification
|
||||
|
||||
Ran `/tmp/verify_summarize_event.py` (standalone minimal reproduction):
|
||||
|
||||
- Minimal `AgentMiddleware` subclass with `abefore_model(self, state, runtime, config: RunnableConfig)`
|
||||
- Calls `await adispatch_custom_event("summarization", {...}, config=config)` inside
|
||||
- `create_agent(model=FakeChatModel, middleware=[probe])`
|
||||
- `agent.ainvoke({...}, config={"callbacks": [RecordingHandler()]})`
|
||||
|
||||
**Result**:
|
||||
|
||||
```
|
||||
INFO verify: ProbeMiddleware.abefore_model called
|
||||
INFO verify: config keys: ['callbacks', 'configurable', 'metadata']
|
||||
INFO verify: config.callbacks type: AsyncCallbackManager
|
||||
INFO verify: config.metadata: {'langgraph_step': 1, 'langgraph_node': 'probe.before_model', ...}
|
||||
INFO verify: on_custom_event fired: name=summarization
|
||||
run_id=019d7d19-1727-7830-aa33-648ecbee4b95
|
||||
data={'summary': 'fake summary', 'replaced_count': 3}
|
||||
SUCCESS: approach A is viable (config injection + adispatch work)
|
||||
```
|
||||
|
||||
All five predictions held:
|
||||
|
||||
1. ✅ `config: RunnableConfig` signature triggers auto-injection despite `trace=False`
|
||||
2. ✅ `config.callbacks` is an `AsyncCallbackManager` with `parent_run_id` set
|
||||
3. ✅ `adispatch_custom_event(..., config=config)` runs without error
|
||||
4. ✅ `RecordingHandler.on_custom_event` receives the event
|
||||
5. ✅ The received `run_id` is a valid UUID tied to the running graph
|
||||
|
||||
**Bonus finding**: `config.metadata` contains `langgraph_step` and `langgraph_node`. These can be included in the middleware event's metadata to help the frontend position the marker on the timeline.
|
||||
|
||||
## 6. Recommended implementation (approach C)
|
||||
|
||||
### 6.1 Backend
|
||||
|
||||
**New wrapper middleware** in `backend/packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
class _TrackingSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""Wraps upstream SummarizationMiddleware to emit a ``summarization``
|
||||
custom event on every actual summarization, so RunJournal can persist
|
||||
a middleware:summarize row to the event store.
|
||||
|
||||
The upstream class does not emit events of its own. Declaring
|
||||
``config: RunnableConfig`` in the override lets LangGraph's
|
||||
``RunnableCallable`` inject the Pregel task config (with callbacks
|
||||
and parent_run_id) regardless of ``trace=False`` on the node.
|
||||
"""
|
||||
|
||||
async def abefore_model(self, state, runtime, config: RunnableConfig):
|
||||
before_count = len(state.get("messages") or [])
|
||||
result = await super().abefore_model(state, runtime)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
new_messages = result.get("messages") or []
|
||||
replaced_count = max(0, before_count - len(new_messages))
|
||||
summary_text = _extract_summary_text(new_messages)
|
||||
|
||||
await adispatch_custom_event(
|
||||
"summarization",
|
||||
{
|
||||
"summary": summary_text,
|
||||
"replaced_count": replaced_count,
|
||||
},
|
||||
config=config,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _extract_summary_text(messages: list) -> str:
|
||||
"""Pull the summary string out of the HumanMessage the upstream class
|
||||
injects as ``Here is a summary of the conversation to date:...``."""
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) == "human":
|
||||
content = getattr(msg, "content", "")
|
||||
text = content if isinstance(content, str) else ""
|
||||
if text.startswith("Here is a summary of the conversation to date"):
|
||||
return text
|
||||
return ""
|
||||
```
|
||||
|
||||
Swap the existing `SummarizationMiddleware()` instantiation in `_build_middlewares` for `_TrackingSummarizationMiddleware(...)` with the same args.
|
||||
|
||||
**Journal change**: **zero**. `on_custom_event("summarization", ...)` in `journal.py:343-362` already writes both a trace and a `category="middleware"` row.
|
||||
|
||||
**History helper change**: extend `_get_event_store_messages` in `backend/app/gateway/routers/threads.py` to surface `category="middleware"` rows as pseudo-messages, e.g.:
|
||||
|
||||
```python
|
||||
# In the per-event loop, after the existing message branch:
|
||||
if evt.get("category") == "middleware" and evt.get("event_type") == "middleware:summarize":
|
||||
meta = evt.get("metadata") or {}
|
||||
messages.append({
|
||||
"id": f"summary-marker-{evt['seq']}",
|
||||
"type": "summary_marker",
|
||||
"replaced_count": meta.get("replaced_count", 0),
|
||||
"summary": (raw or {}).get("content", "") if isinstance(raw, dict) else "",
|
||||
"run_id": evt.get("run_id"),
|
||||
})
|
||||
```
|
||||
|
||||
The marker uses a sentinel `type` (`summary_marker`) that doesn't collide with any LangChain message type, so downstream consumers that loop over messages can skip or render it explicitly.
|
||||
|
||||
### 6.2 Frontend
|
||||
|
||||
- `core/messages/utils.ts`: extend the message grouping to recognize `type === "summary_marker"` and yield it as its own group (`"assistant:summary-marker"`)
|
||||
- `components/workspace/messages/message-list.tsx`: add a branch in the grouped render switch that renders a distinctive inline card showing `N messages condensed` and a collapsible panel with the summary text
|
||||
- No changes to feedback logic: the marker has no `feedback` field so the button naturally doesn't render on it
|
||||
|
||||
## 7. Risks
|
||||
|
||||
1. **Synchronous path**. The upstream class has both `before_model` and `abefore_model`. Our wrapper only overrides the async variant. If any deer-flow code path ever uses the sync flow, those summarizations won't be captured. Mitigation: also override `before_model` and use `dispatch_custom_event` (sync variant) with the same pattern.
|
||||
2. **`_extract_summary_text` fragility**. It depends on the upstream class prefix `"Here is a summary of the conversation to date"` in the injected `HumanMessage`. Any upstream template change breaks detection. Mitigation: pick the first new `HumanMessage` that wasn't in `state["messages"]` before super() — resilient to template wording changes at the cost of a small diff helper.
|
||||
3. **`replaced_count` accuracy when concurrent updates**. If another middleware in the chain also modifies `state["messages"]` before super() returns, the naive `before_count - len(new_messages)` arithmetic is wrong. Mitigation: inspect the `RemoveMessage(id=REMOVE_ALL_MESSAGES)` that upstream emits and count from the original input list directly.
|
||||
4. **History helper contract change**. Introducing a non-LangChain-typed entry (`type="summary_marker"`) in the `/history` response could break frontend code that blindly casts entries to `Message`. Mitigation: the frontend change above adds an explicit branch; type-check the frontend end-to-end before merging.
|
||||
|
||||
## 8. Out of scope / deferred
|
||||
|
||||
- Other middleware types (Title, Guardrail, HITL) do not emit custom events either. If we want markers for those too, repeat the wrapper pattern for each. Not in this design.
|
||||
- Retroactive markers for old threads (captured before this patch) are impossible without re-running the graph. Legacy threads will show the event-store-recovered messages without a marker.
|
||||
- Standard mode (`make dev`) — agent runs inside LangGraph Server, not the Gateway-embedded runtime. `RunJournal` may not be wired there, so the custom event fires but is captured by no one. Tracked as a separate follow-up.
|
||||
|
||||
## 9. Next actions
|
||||
|
||||
1. Land the current summarize-message-loss fixes (journal `Command` unwrap + event-store-backed `/history` + inline feedback) — implementation verified, being committed now as three commits on `rayhpeng/fix-persistence-new`
|
||||
2. Summarize-marker implementation (this spec) → separate follow-up PR based on the above verified design
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user