Compare commits
2 Commits
main ... release/2.0-rc-1

| Author | SHA1 | Date |
|---|---|---|
|  | e75a2ff29a |  |
|  | 185f5649dd |  |
@@ -33,5 +33,9 @@ INFOQUEST_API_KEY=your-infoquest-api-key

# GitHub API Token
# GITHUB_TOKEN=your-github-token

# Database (only needed when config.yaml has database.backend: postgres)
# DATABASE_URL=postgresql://deerflow:password@localhost:5432/deerflow
#
# WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret

+5 −1

@@ -13,6 +13,9 @@ FROM python:3.12-slim-bookworm AS builder
ARG NODE_MAJOR=22
ARG APT_MIRROR
ARG UV_INDEX_URL
# Optional extras to install (e.g. "postgres" for PostgreSQL support)
# Usage: docker build --build-arg UV_EXTRAS=postgres ...
ARG UV_EXTRAS

# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
RUN if [ -n "${APT_MIRROR}" ]; then \

@@ -43,8 +46,9 @@ WORKDIR /app
COPY backend ./backend

# Install dependencies with cache mount
# When UV_EXTRAS is set (e.g. "postgres"), installs optional dependencies.
RUN --mount=type=cache,target=/root/.cache/uv \
-    sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
+    sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"

# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build

+161 −1
@@ -1,16 +1,23 @@
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
    agents,
    artifacts,
    assistants_compat,
    auth,
    channels,
    feedback,
    mcp,
    memory,
    models,
@@ -33,6 +40,125 @@ logging.basicConfig(
logger = logging.getLogger(__name__)


async def _ensure_admin_user(app: FastAPI) -> None:
    """Auto-create the admin user on first boot if no users exist.

    After admin creation, migrate orphan threads from the LangGraph
    store (metadata.owner_id unset) to the admin account. This is the
    "no-auth → with-auth" upgrade path: users who ran DeerFlow without
    authentication have existing LangGraph thread data that needs an
    owner assigned.

    No SQL persistence migration is needed: the four owner_id columns
    (threads_meta, runs, run_events, feedback) only come into existence
    alongside the auth module via create_all, so freshly created tables
    never contain NULL-owner rows. "Existing persistence DB + new auth"
    is not a supported upgrade path — fresh install or wipe-and-retry.

    Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
    races during admin creation. Only the worker that successfully
    creates/updates the admin writes the credentials file; losers
    silently skip.
    """
    import secrets

    from app.gateway.auth.credential_file import write_initial_credentials
    from app.gateway.deps import get_local_provider

    def _announce_credentials(email: str, password: str, *, label: str, headline: str) -> None:
        """Write the password to a 0600 file and log the path (never the secret)."""
        cred_path = write_initial_credentials(email, password, label=label)
        logger.info("=" * 60)
        logger.info(" %s", headline)
        logger.info(" Credentials written to: %s (mode 0600)", cred_path)
        logger.info(" Change it after login: Settings -> Account")
        logger.info("=" * 60)

    provider = get_local_provider()
    user_count = await provider.count_users()

    admin = None

    if user_count == 0:
        password = secrets.token_urlsafe(16)
        try:
            admin = await provider.create_user(
                email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True
            )
        except ValueError:
            return  # Another worker already created the admin.
        _announce_credentials(admin.email, password, label="initial", headline="Admin account created on first boot")
    else:
        # Admin exists but setup never completed — reset the password so the
        # operator can always recover it without needing the CLI.
        # Multi-worker guard: if admin was created less than 30s ago, another
        # worker just created it and will write the credentials file — skip reset.
        admin = await provider.get_user_by_email("admin@deerflow.dev")
        if admin and admin.needs_setup:
            import time

            age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
            if age >= 30:
                from app.gateway.auth.password import hash_password_async

                password = secrets.token_urlsafe(16)
                admin.password_hash = await hash_password_async(password)
                admin.token_version += 1
                await provider.update_user(admin)
                _announce_credentials(admin.email, password, label="reset", headline="Admin account setup incomplete — password reset")

    if admin is None:
        return  # Nothing to bind orphans to.

    admin_id = str(admin.id)

    # LangGraph store orphan migration — non-fatal.
    # This covers the "no-auth → with-auth" upgrade path for users
    # whose existing LangGraph thread metadata has no owner_id set.
    store = getattr(app.state, "store", None)
    if store is not None:
        try:
            migrated = await _migrate_orphaned_threads(store, admin_id)
            if migrated:
                logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
        except Exception:
            logger.exception("LangGraph thread migration failed (non-fatal)")


async def _iter_store_items(store, namespace, *, page_size: int = 500):
    """Paginated async iterator over a LangGraph store namespace.

    Replaces the old hardcoded ``limit=1000`` call with a cursor-style
    loop so that environments with more than one page of orphans do
    not silently lose data. Terminates when a page is empty OR when a
    short page arrives (indicating the last page).
    """
    offset = 0
    while True:
        batch = await store.asearch(namespace, limit=page_size, offset=offset)
        if not batch:
            return
        for item in batch:
            yield item
        if len(batch) < page_size:
            return
        offset += page_size


async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
    """Migrate LangGraph store threads with no owner_id to the given admin.

    Uses cursor pagination so all orphans are migrated regardless of
    count. Returns the number of rows migrated.
    """
    migrated = 0
    async for item in _iter_store_items(store, ("threads",)):
        metadata = item.value.get("metadata", {})
        if not metadata.get("owner_id"):
            metadata["owner_id"] = admin_user_id
            item.value["metadata"] = metadata
            await store.aput(("threads",), item.key, item.value)
            migrated += 1
    return migrated


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan handler."""
@@ -52,6 +178,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    async with langgraph_runtime(app):
        logger.info("LangGraph runtime initialised")

        # Ensure admin user exists (auto-create on first boot)
        # Must run AFTER langgraph_runtime so app.state.store is available for thread migration
        await _ensure_admin_user(app)

        # Start IM channel service if any channels are configured
        try:
            from app.channels.service import start_channel_service
@@ -163,7 +293,31 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
    ],
)

-# CORS is handled by nginx - no need for FastAPI middleware
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)

# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)

# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
    cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
    # Validate: wildcard origin with credentials is a security misconfiguration
    for origin in cors_origins:
        if origin == "*":
            logger.error(
                "GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. "
                "This is a security misconfiguration — browsers will reject the response. "
                "Use explicit scheme://host:port origins instead."
            )
            cors_origins = [o for o in cors_origins if o != "*"]
            break
    if cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

# Include routers
# Models API is mounted at /api/models

@@ -199,6 +353,12 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)

# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)

# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)

# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)

@@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.

This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""

from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository

__all__ = [
    # Config
    "AuthConfig",
    "get_auth_config",
    "set_auth_config",
    # Errors
    "AuthErrorCode",
    "AuthErrorResponse",
    "TokenError",
    # JWT
    "TokenPayload",
    "create_access_token",
    "decode_token",
    # Password
    "hash_password",
    "verify_password",
    # Models
    "User",
    "UserResponse",
    # Providers
    "AuthProvider",
    "LocalAuthProvider",
    # Repository
    "UserRepository",
]
@@ -0,0 +1,57 @@
"""Authentication configuration for DeerFlow."""

import logging
import os
import secrets

from dotenv import load_dotenv
from pydantic import BaseModel, Field

load_dotenv()

logger = logging.getLogger(__name__)


class AuthConfig(BaseModel):
    """JWT and auth-related configuration. Parsed once at startup.

    Note: the ``users`` table now lives in the shared persistence
    database managed by ``deerflow.persistence.engine``. The old
    ``users_db_path`` config key has been removed — user storage is
    configured through ``config.database`` like every other table.
    """

    jwt_secret: str = Field(
        ...,
        description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
    )
    token_expiry_days: int = Field(default=7, ge=1, le=30)
    oauth_github_client_id: str | None = Field(default=None)
    oauth_github_client_secret: str | None = Field(default=None)


_auth_config: AuthConfig | None = None


def get_auth_config() -> AuthConfig:
    """Get the global AuthConfig instance. Parses from env on first call."""
    global _auth_config
    if _auth_config is None:
        jwt_secret = os.environ.get("AUTH_JWT_SECRET")
        if not jwt_secret:
            jwt_secret = secrets.token_urlsafe(32)
            os.environ["AUTH_JWT_SECRET"] = jwt_secret
            logger.warning(
                "⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
                "Sessions will be invalidated on restart. "
                "For production, add AUTH_JWT_SECRET to your .env file: "
                'python -c "import secrets; print(secrets.token_urlsafe(32))"'
            )
        _auth_config = AuthConfig(jwt_secret=jwt_secret)
    return _auth_config


def set_auth_config(config: AuthConfig) -> None:
    """Set the global AuthConfig instance (for testing)."""
    global _auth_config
    _auth_config = config
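
Example (not part of the diff): a minimal test-setup sketch around set_auth_config, assuming pytest is the project's test runner.

    import pytest

    from app.gateway.auth.config import AuthConfig, set_auth_config


    @pytest.fixture(autouse=True)
    def fixed_auth_config():
        # Pin a deterministic secret so tokens minted in one test decode in
        # another. A real suite would also restore the previous config after.
        set_auth_config(AuthConfig(jwt_secret="test-secret-do-not-use"))
        yield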

@@ -0,0 +1,48 @@
"""Write initial admin credentials to a restricted file instead of logs.

Logging secrets to stdout/stderr is a well-known CodeQL finding
(py/clear-text-logging-sensitive-data) — in production those logs
get collected into ELK/Splunk/etc and become a secret-sprawl
source. This helper writes the credential to a 0600 file that only
the process user can read, and returns the path so the caller can
log **the path** (not the password) for the operator to pick up.
"""

from __future__ import annotations

import os
from pathlib import Path

from deerflow.config.paths import get_paths

_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"


def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
    """Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.

    The file is created **atomically** with mode 0600 via ``os.open``
    so the password is never world-readable, even for the single-syscall
    window between ``write_text`` and ``chmod``.

    ``label`` distinguishes "initial" (fresh creation) from "reset"
    (password reset) in the file header so an operator picking up the
    file after a restart can tell which event produced it.

    Returns the absolute :class:`Path` to the file.
    """
    target = get_paths().base_dir / _CREDENTIAL_FILENAME
    target.parent.mkdir(parents=True, exist_ok=True)

    content = (
        f"# DeerFlow admin {label} credentials\n"
        "# This file is generated on first boot or password reset.\n"
        "# Change the password after login via Settings -> Account,\n"
        "# then delete this file.\n"
        "#\n"
        f"email: {email}\n"
        f"password: {password}\n"
    )

    # Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
    # reset-password path can rewrite an existing file without a
    # separate unlink-then-create dance.
    fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        fh.write(content)

    return target.resolve()

@@ -0,0 +1,44 @@
"""Typed error definitions for auth module.

AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""

from enum import StrEnum

from pydantic import BaseModel


class AuthErrorCode(StrEnum):
    """Exhaustive list of auth error conditions."""

    INVALID_CREDENTIALS = "invalid_credentials"
    TOKEN_EXPIRED = "token_expired"
    TOKEN_INVALID = "token_invalid"
    USER_NOT_FOUND = "user_not_found"
    EMAIL_ALREADY_EXISTS = "email_already_exists"
    PROVIDER_NOT_FOUND = "provider_not_found"
    NOT_AUTHENTICATED = "not_authenticated"


class TokenError(StrEnum):
    """Exhaustive list of JWT decode failure reasons."""

    EXPIRED = "expired"
    INVALID_SIGNATURE = "invalid_signature"
    MALFORMED = "malformed"


class AuthErrorResponse(BaseModel):
    """Structured error response — replaces bare `detail` strings."""

    code: AuthErrorCode
    message: str


def token_error_to_code(err: TokenError) -> AuthErrorCode:
    """Map TokenError to AuthErrorCode — single source of truth."""
    if err == TokenError.EXPIRED:
        return AuthErrorCode.TOKEN_EXPIRED
    return AuthErrorCode.TOKEN_INVALID

@@ -0,0 +1,55 @@
"""JWT token creation and verification."""

from datetime import UTC, datetime, timedelta

import jwt
from pydantic import BaseModel

from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError


class TokenPayload(BaseModel):
    """JWT token payload."""

    sub: str  # user_id
    exp: datetime
    iat: datetime | None = None
    ver: int = 0  # token_version — must match User.token_version


def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
    """Create a JWT access token.

    Args:
        user_id: The user's UUID as string
        expires_delta: Optional custom expiry, defaults to 7 days
        token_version: User's current token_version for invalidation

    Returns:
        Encoded JWT string
    """
    config = get_auth_config()
    expiry = expires_delta or timedelta(days=config.token_expiry_days)

    now = datetime.now(UTC)
    payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
    return jwt.encode(payload, config.jwt_secret, algorithm="HS256")


def decode_token(token: str) -> TokenPayload | TokenError:
    """Decode and validate a JWT token.

    Returns:
        TokenPayload if valid, or a specific TokenError variant.
    """
    config = get_auth_config()
    try:
        payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
        return TokenPayload(**payload)
    except jwt.ExpiredSignatureError:
        return TokenError.EXPIRED
    except jwt.InvalidSignatureError:
        return TokenError.INVALID_SIGNATURE
    except jwt.PyJWTError:
        return TokenError.MALFORMED
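
Example (not part of the diff): a round-trip sketch of the two functions above; the user id is a placeholder, and get_auth_config() supplies the secret.

    from app.gateway.auth.errors import TokenError
    from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token

    token = create_access_token(user_id="00000000-0000-0000-0000-000000000001", token_version=1)
    result = decode_token(token)
    if isinstance(result, TokenPayload):
        print(result.sub, result.ver)  # the placeholder uuid, 1
    else:
        # result is a TokenError variant: EXPIRED, INVALID_SIGNATURE, or MALFORMED
        print(f"rejected: {result}")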

@@ -0,0 +1,87 @@
"""Local email/password authentication provider."""

from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository


class LocalAuthProvider(AuthProvider):
    """Email/password authentication provider using local database."""

    def __init__(self, repository: UserRepository):
        """Initialize with a UserRepository.

        Args:
            repository: UserRepository implementation (SQLite)
        """
        self._repo = repository

    async def authenticate(self, credentials: dict) -> User | None:
        """Authenticate with email and password.

        Args:
            credentials: dict with 'email' and 'password' keys

        Returns:
            User if authentication succeeds, None otherwise
        """
        email = credentials.get("email")
        password = credentials.get("password")

        if not email or not password:
            return None

        user = await self._repo.get_user_by_email(email)
        if user is None:
            return None

        if user.password_hash is None:
            # OAuth user without local password
            return None

        if not await verify_password_async(password, user.password_hash):
            return None

        return user

    async def get_user(self, user_id: str) -> User | None:
        """Get user by ID."""
        return await self._repo.get_user_by_id(user_id)

    async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
        """Create a new local user.

        Args:
            email: User email address
            password: Plain text password (will be hashed)
            system_role: Role to assign ("admin" or "user")
            needs_setup: If True, user must complete setup on first login

        Returns:
            Created User instance
        """
        password_hash = await hash_password_async(password) if password else None
        user = User(
            email=email,
            password_hash=password_hash,
            system_role=system_role,
            needs_setup=needs_setup,
        )
        return await self._repo.create_user(user)

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID."""
        return await self._repo.get_user_by_oauth(provider, oauth_id)

    async def count_users(self) -> int:
        """Return total number of registered users."""
        return await self._repo.count_users()

    async def update_user(self, user: User) -> User:
        """Update an existing user."""
        return await self._repo.update_user(user)

    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email."""
        return await self._repo.get_user_by_email(email)
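
Example (not part of the diff): calling the provider from async code, assuming a repository has already been constructed (see SQLiteUserRepository below); the credentials are placeholders.

    from app.gateway.auth.local_provider import LocalAuthProvider


    async def login_demo(repo) -> None:
        provider = LocalAuthProvider(repository=repo)
        user = await provider.authenticate({"email": "admin@deerflow.dev", "password": "..."})
        if user is None:
            # Covers bad credentials, missing keys, and OAuth-only accounts.
            print("authentication failed")
        else:
            print(f"authenticated {user.email} (role={user.system_role})")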

@@ -0,0 +1,41 @@
"""User Pydantic models for authentication."""

from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, EmailStr, Field


def _utc_now() -> datetime:
    """Return current UTC time (timezone-aware)."""
    return datetime.now(UTC)


class User(BaseModel):
    """Internal user representation."""

    model_config = ConfigDict(from_attributes=True)

    id: UUID = Field(default_factory=uuid4, description="Primary key")
    email: EmailStr = Field(..., description="Unique email address")
    password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
    system_role: Literal["admin", "user"] = Field(default="user")
    created_at: datetime = Field(default_factory=_utc_now)

    # OAuth linkage (optional)
    oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
    oauth_id: str | None = Field(None, description="User ID from OAuth provider")

    # Auth lifecycle
    needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
    token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")


class UserResponse(BaseModel):
    """Response model for user info endpoint."""

    id: str
    email: str
    system_role: Literal["admin", "user"]
    needs_setup: bool = False

@@ -0,0 +1,33 @@
"""Password hashing utilities using bcrypt directly."""

import asyncio

import bcrypt


def hash_password(password: str) -> str:
    """Hash a password using bcrypt."""
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash."""
    return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))


async def hash_password_async(password: str) -> str:
    """Hash a password using bcrypt (non-blocking).

    Wraps the blocking bcrypt operation in a thread pool to avoid
    blocking the event loop during password hashing.
    """
    return await asyncio.to_thread(hash_password, password)


async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash (non-blocking).

    Wraps the blocking bcrypt operation in a thread pool to avoid
    blocking the event loop during password verification.
    """
    return await asyncio.to_thread(verify_password, plain_password, hashed_password)
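
Example (not part of the diff): a hash/verify round trip; the async variants offload bcrypt to a worker thread so the event loop stays responsive.

    import asyncio

    from app.gateway.auth.password import hash_password_async, verify_password_async


    async def demo() -> None:
        digest = await hash_password_async("hunter2")
        assert await verify_password_async("hunter2", digest)
        assert not await verify_password_async("wrong", digest)


    asyncio.run(demo())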

@@ -0,0 +1,24 @@
"""Auth provider abstraction."""

from abc import ABC, abstractmethod


class AuthProvider(ABC):
    """Abstract base class for authentication providers."""

    @abstractmethod
    async def authenticate(self, credentials: dict) -> "User | None":
        """Authenticate user with given credentials.

        Returns User if authentication succeeds, None otherwise.
        """
        ...

    @abstractmethod
    async def get_user(self, user_id: str) -> "User | None":
        """Retrieve user by ID."""
        ...


# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User  # noqa: E402

@@ -0,0 +1,97 @@
"""User repository interface for abstracting database operations."""

from abc import ABC, abstractmethod

from app.gateway.auth.models import User


class UserNotFoundError(LookupError):
    """Raised when a user repository operation targets a non-existent row.

    Subclass of :class:`LookupError` so callers that already catch
    ``LookupError`` for "missing entity" can keep working unchanged,
    while specific call sites can pin to this class to distinguish
    "concurrent delete during update" from other lookups.
    """


class UserRepository(ABC):
    """Abstract interface for user data storage.

    Implement this interface to support different storage backends
    (currently SQLite).
    """

    @abstractmethod
    async def create_user(self, user: User) -> User:
        """Create a new user.

        Args:
            user: User object to create

        Returns:
            Created User with ID assigned

        Raises:
            ValueError: If email already exists
        """
        ...

    @abstractmethod
    async def get_user_by_id(self, user_id: str) -> User | None:
        """Get user by ID.

        Args:
            user_id: User UUID as string

        Returns:
            User if found, None otherwise
        """
        ...

    @abstractmethod
    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email.

        Args:
            email: User email address

        Returns:
            User if found, None otherwise
        """
        ...

    @abstractmethod
    async def update_user(self, user: User) -> User:
        """Update an existing user.

        Args:
            user: User object with updated fields

        Returns:
            Updated User

        Raises:
            UserNotFoundError: If no row exists for ``user.id``. This is
                a hard failure (not a no-op) so callers cannot mistake a
                concurrent-delete race for a successful update.
        """
        ...

    @abstractmethod
    async def count_users(self) -> int:
        """Return total number of registered users."""
        ...

    @abstractmethod
    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID.

        Args:
            provider: OAuth provider name (e.g. 'github', 'google')
            oauth_id: User ID from the OAuth provider

        Returns:
            User if found, None otherwise
        """
        ...

@@ -0,0 +1,122 @@
"""SQLAlchemy-backed UserRepository implementation.

Uses the shared async session factory from
``deerflow.persistence.engine`` — the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.

Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""

from __future__ import annotations

from datetime import UTC
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
from deerflow.persistence.user.model import UserRow


class SQLiteUserRepository(UserRepository):
    """Async user repository backed by the shared SQLAlchemy engine."""

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    # ── Converters ────────────────────────────────────────────────────

    @staticmethod
    def _row_to_user(row: UserRow) -> User:
        return User(
            id=UUID(row.id),
            email=row.email,
            password_hash=row.password_hash,
            system_role=row.system_role,  # type: ignore[arg-type]
            # SQLite loses tzinfo on read; reattach UTC so downstream
            # code can compare timestamps reliably.
            created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
            oauth_provider=row.oauth_provider,
            oauth_id=row.oauth_id,
            needs_setup=row.needs_setup,
            token_version=row.token_version,
        )

    @staticmethod
    def _user_to_row(user: User) -> UserRow:
        return UserRow(
            id=str(user.id),
            email=user.email,
            password_hash=user.password_hash,
            system_role=user.system_role,
            created_at=user.created_at,
            oauth_provider=user.oauth_provider,
            oauth_id=user.oauth_id,
            needs_setup=user.needs_setup,
            token_version=user.token_version,
        )

    # ── CRUD ──────────────────────────────────────────────────────────

    async def create_user(self, user: User) -> User:
        """Insert a new user. Raises ``ValueError`` on duplicate email."""
        row = self._user_to_row(user)
        async with self._sf() as session:
            session.add(row)
            try:
                await session.commit()
            except IntegrityError as exc:
                await session.rollback()
                raise ValueError(f"Email already registered: {user.email}") from exc
        return user

    async def get_user_by_id(self, user_id: str) -> User | None:
        async with self._sf() as session:
            row = await session.get(UserRow, user_id)
            return self._row_to_user(row) if row is not None else None

    async def get_user_by_email(self, email: str) -> User | None:
        stmt = select(UserRow).where(UserRow.email == email)
        async with self._sf() as session:
            result = await session.execute(stmt)
            row = result.scalar_one_or_none()
            return self._row_to_user(row) if row is not None else None

    async def update_user(self, user: User) -> User:
        async with self._sf() as session:
            row = await session.get(UserRow, str(user.id))
            if row is None:
                # Hard fail on concurrent delete: callers (reset_admin,
                # password change handlers, _ensure_admin_user) all
                # fetched the user just before this call, so a missing
                # row here means the row vanished underneath us. Silent
                # success would let the caller log "password reset" for
                # a row that no longer exists.
                raise UserNotFoundError(f"User {user.id} no longer exists")
            row.email = user.email
            row.password_hash = user.password_hash
            row.system_role = user.system_role
            row.oauth_provider = user.oauth_provider
            row.oauth_id = user.oauth_id
            row.needs_setup = user.needs_setup
            row.token_version = user.token_version
            await session.commit()
        return user

    async def count_users(self) -> int:
        stmt = select(func.count()).select_from(UserRow)
        async with self._sf() as session:
            return await session.scalar(stmt) or 0

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
        async with self._sf() as session:
            result = await session.execute(stmt)
            row = result.scalar_one_or_none()
            return self._row_to_user(row) if row is not None else None

@@ -0,0 +1,91 @@
"""CLI tool to reset an admin password.

Usage:
    python -m app.gateway.auth.reset_admin
    python -m app.gateway.auth.reset_admin --email admin@example.com

Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
(mode 0600) instead of printing it, so CI / log aggregators never see
the cleartext secret.
"""

from __future__ import annotations

import argparse
import asyncio
import secrets
import sys

from sqlalchemy import select

from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.user.model import UserRow


async def _run(email: str | None) -> int:
    from deerflow.config import get_app_config
    from deerflow.persistence.engine import (
        close_engine,
        get_session_factory,
        init_engine_from_config,
    )

    config = get_app_config()
    await init_engine_from_config(config.database)
    try:
        sf = get_session_factory()
        if sf is None:
            print("Error: persistence engine not available (check config.database).", file=sys.stderr)
            return 1

        repo = SQLiteUserRepository(sf)

        if email:
            user = await repo.get_user_by_email(email)
        else:
            # Find first admin via direct SELECT — repository does not
            # expose a "first admin" helper and we do not want to add
            # one just for this CLI.
            async with sf() as session:
                stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
                row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                user = None
            else:
                user = await repo.get_user_by_id(row.id)

        if user is None:
            if email:
                print(f"Error: user '{email}' not found.", file=sys.stderr)
            else:
                print("Error: no admin user found.", file=sys.stderr)
            return 1

        new_password = secrets.token_urlsafe(16)
        user.password_hash = hash_password(new_password)
        user.token_version += 1
        user.needs_setup = True
        await repo.update_user(user)

        cred_path = write_initial_credentials(user.email, new_password, label="reset")
        print(f"Password reset for: {user.email}")
        print(f"Credentials written to: {cred_path} (mode 0600)")
        print("Next login will require setup (new email + password).")
        return 0
    finally:
        await close_engine()


def main() -> None:
    parser = argparse.ArgumentParser(description="Reset admin password")
    parser.add_argument("--email", help="Admin email (default: first admin found)")
    args = parser.parse_args()

    exit_code = asyncio.run(_run(args.email))
    sys.exit(exit_code)


if __name__ == "__main__":
    main()

@@ -0,0 +1,117 @@
"""Global authentication middleware — fail-closed safety net.

Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.

Fine-grained permission checks remain in authz.py decorators.
"""

from collections.abc import Callable

from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from deerflow.runtime.user_context import reset_current_user, set_current_user

# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
)

# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
    {
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
    }
)


def _is_public(path: str) -> bool:
    stripped = path.rstrip("/")
    if stripped in _PUBLIC_EXACT_PATHS:
        return True
    return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)


class AuthMiddleware(BaseHTTPMiddleware):
    """Strict auth gate: reject requests without a valid session.

    Two-stage check for non-public paths:

    1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
    2. JWT validation via ``get_current_user_from_request`` — return 401
       TOKEN_INVALID if the token is absent, malformed, expired, or the
       signed user does not exist / is stale

    On success, stamps ``request.state.user`` and the
    ``deerflow.runtime.user_context`` contextvar so that repository-layer
    owner filters work downstream without every route needing a
    ``@require_auth`` decorator. Routes that need per-resource
    authorization (e.g. "user A cannot read user B's thread by guessing
    the URL") should additionally use ``@require_permission(...,
    owner_check=True)`` for explicit enforcement — but authentication
    itself is fully handled here.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if _is_public(request.url.path):
            return await call_next(request)

        # Non-public path: require session cookie
        if not request.cookies.get("access_token"):
            return JSONResponse(
                status_code=401,
                content={
                    "detail": AuthErrorResponse(
                        code=AuthErrorCode.NOT_AUTHENTICATED,
                        message="Authentication required",
                    ).model_dump()
                },
            )

        # Strict JWT validation: reject junk/expired tokens with 401
        # right here instead of silently passing through. This closes
        # the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
        # without this, non-isolation routes like /api/models would
        # accept any cookie-shaped string as authentication.
        #
        # We call the *strict* resolver so that fine-grained error
        # codes (token_expired, token_invalid, user_not_found, …)
        # propagate from AuthErrorCode, not get flattened into one
        # generic code. BaseHTTPMiddleware doesn't let HTTPException
        # bubble up, so we catch and render it as JSONResponse here.
        from app.gateway.deps import get_current_user_from_request

        try:
            user = await get_current_user_from_request(request)
        except HTTPException as exc:
            return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

        # Stamp both request.state.user (for the contextvar pattern)
        # and request.state.auth (so @require_permission's "auth is
        # None" branch short-circuits instead of running the entire
        # JWT-decode + DB-lookup pipeline a second time per request).
        request.state.user = user
        request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
        token = set_current_user(user)
        try:
            return await call_next(request)
        finally:
            reset_current_user(token)

@@ -0,0 +1,262 @@
"""Authorization decorators and context for DeerFlow.

Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py

**Usage:**

1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", owner_check=...)`` for permission checks
3. The decorator chain processes from bottom to top

**Example:**

    @router.get("/{thread_id}")
    @require_auth
    @require_permission("threads", "read", owner_check=True)
    async def get_thread(thread_id: str, request: Request):
        # User is authenticated and has threads:read permission
        ...

**Permission Model:**

- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""

from __future__ import annotations

import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

from fastapi import HTTPException, Request

if TYPE_CHECKING:
    from app.gateway.auth.models import User

P = ParamSpec("P")
T = TypeVar("T")


# Permission constants
class Permissions:
    """Permission constants for resource:action format."""

    # Threads
    THREADS_READ = "threads:read"
    THREADS_WRITE = "threads:write"
    THREADS_DELETE = "threads:delete"

    # Runs
    RUNS_CREATE = "runs:create"
    RUNS_READ = "runs:read"
    RUNS_CANCEL = "runs:cancel"


class AuthContext:
    """Authentication context for the current request.

    Stored in request.state.auth after require_auth decoration.

    Attributes:
        user: The authenticated user, or None if anonymous
        permissions: List of permission strings (e.g., "threads:read")
    """

    __slots__ = ("user", "permissions")

    def __init__(self, user: User | None = None, permissions: list[str] | None = None):
        self.user = user
        self.permissions = permissions or []

    @property
    def is_authenticated(self) -> bool:
        """Check if user is authenticated."""
        return self.user is not None

    def has_permission(self, resource: str, action: str) -> bool:
        """Check if context has permission for resource:action.

        Args:
            resource: Resource name (e.g., "threads")
            action: Action name (e.g., "read")

        Returns:
            True if user has permission
        """
        permission = f"{resource}:{action}"
        return permission in self.permissions

    def require_user(self) -> User:
        """Get user or raise 401.

        Raises:
            HTTPException 401 if not authenticated
        """
        if not self.user:
            raise HTTPException(status_code=401, detail="Authentication required")
        return self.user


def get_auth_context(request: Request) -> AuthContext | None:
    """Get AuthContext from request state."""
    return getattr(request.state, "auth", None)


_ALL_PERMISSIONS: list[str] = [
    Permissions.THREADS_READ,
    Permissions.THREADS_WRITE,
    Permissions.THREADS_DELETE,
    Permissions.RUNS_CREATE,
    Permissions.RUNS_READ,
    Permissions.RUNS_CANCEL,
]


async def _authenticate(request: Request) -> AuthContext:
    """Authenticate request and return AuthContext.

    Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
    Returns AuthContext with user=None for anonymous requests.
    """
    from app.gateway.deps import get_optional_user_from_request

    user = await get_optional_user_from_request(request)
    if user is None:
        return AuthContext(user=None, permissions=[])

    # In future, permissions could be stored in user record
    return AuthContext(user=user, permissions=_ALL_PERMISSIONS)


def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    """Decorator that authenticates the request and sets AuthContext.

    Must be placed ABOVE other decorators (executes after them).

    Usage:
        @router.get("/{thread_id}")
        @require_auth  # Bottom decorator (executes first after permission check)
        @require_permission("threads", "read")
        async def get_thread(thread_id: str, request: Request):
            auth: AuthContext = request.state.auth
            ...

    Raises:
        ValueError: If 'request' parameter is missing
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        request = kwargs.get("request")
        if request is None:
            raise ValueError("require_auth decorator requires 'request' parameter")

        # Authenticate and set context
        auth_context = await _authenticate(request)
        request.state.auth = auth_context

        return await func(*args, **kwargs)

    return wrapper


def require_permission(
    resource: str,
    action: str,
    owner_check: bool = False,
    require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorator that checks permission for resource:action.

    Must be used AFTER @require_auth.

    Args:
        resource: Resource name (e.g., "threads", "runs")
        action: Action name (e.g., "read", "write", "delete")
        owner_check: If True, validates that the current user owns the resource.
            Requires 'thread_id' path parameter and performs ownership check.
        require_existing: Only meaningful with ``owner_check=True``. If True, a
            missing ``threads_meta`` row counts as a denial (404)
            instead of "untracked legacy thread, allow". Use on
            **destructive / mutating** routes (DELETE, PATCH,
            state-update) so a deleted thread can't be re-targeted
            by another user via the missing-row code path.

    Usage:
        # Read-style: legacy untracked threads are allowed
        @require_permission("threads", "read", owner_check=True)
        async def get_thread(thread_id: str, request: Request):
            ...

        # Destructive: thread row MUST exist and be owned by caller
        @require_permission("threads", "delete", owner_check=True, require_existing=True)
        async def delete_thread(thread_id: str, request: Request):
            ...

    Raises:
        HTTPException 401: If authentication required but user is anonymous
        HTTPException 403: If user lacks permission
        HTTPException 404: If owner_check=True but user doesn't own the thread
        ValueError: If owner_check=True but 'thread_id' parameter is missing
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            request = kwargs.get("request")
            if request is None:
                raise ValueError("require_permission decorator requires 'request' parameter")

            auth: AuthContext | None = getattr(request.state, "auth", None)
            if auth is None:
                auth = await _authenticate(request)
                request.state.auth = auth

            if not auth.is_authenticated:
                raise HTTPException(status_code=401, detail="Authentication required")

            # Check permission
            if not auth.has_permission(resource, action):
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied: {resource}:{action}",
                )

            # Owner check for thread-specific resources.
            #
            # 2.0-rc moved thread metadata into the SQL persistence layer
            # (``threads_meta`` table). We verify ownership via
            # ``ThreadMetaStore.check_access``: it returns True for
            # missing rows (untracked legacy thread) and for rows whose
            # ``owner_id`` is NULL (shared / pre-auth data), so this is
            # strict-deny rather than strict-allow — only an *existing*
            # row with a *different* owner_id triggers 404.
            if owner_check:
                thread_id = kwargs.get("thread_id")
                if thread_id is None:
                    raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")

                from app.gateway.deps import get_thread_meta_repo

                thread_meta_repo = get_thread_meta_repo(request)
                allowed = await thread_meta_repo.check_access(
                    thread_id,
                    str(auth.user.id),
                    require_existing=require_existing,
                )
                if not allowed:
                    raise HTTPException(
                        status_code=404,
                        detail=f"Thread {thread_id} not found",
                    )

            return await func(*args, **kwargs)

        return wrapper

    return decorator

@@ -0,0 +1,112 @@
"""CSRF protection middleware for FastAPI.

Per RFC-001:
    State-changing operations require CSRF protection.
"""

import secrets
from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64  # bytes


def is_secure_request(request: Request) -> bool:
    """Detect whether the original client request was made over HTTPS."""
    return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"


def generate_csrf_token() -> str:
    """Generate a secure random CSRF token."""
    return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)


def should_check_csrf(request: Request) -> bool:
    """Determine if a request needs CSRF validation.

    CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
    GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
    """
    if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
        return False

    path = request.url.path.rstrip("/")
    # Exempt /api/v1/auth/me endpoint
    if path == "/api/v1/auth/me":
        return False
    return True


_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
    {
        "/api/v1/auth/login/local",
        "/api/v1/auth/logout",
        "/api/v1/auth/register",
    }
)


def is_auth_endpoint(request: Request) -> bool:
    """Check if the request is to an auth endpoint.

    Auth endpoints don't need CSRF validation on first call (no token).
    """
    return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS


class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware that implements CSRF protection using Double Submit Cookie pattern."""

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        _is_auth = is_auth_endpoint(request)

        if should_check_csrf(request) and not _is_auth:
            cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
            header_token = request.headers.get(CSRF_HEADER_NAME)

            if not cookie_token or not header_token:
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
                )

            if not secrets.compare_digest(cookie_token, header_token):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token mismatch."},
                )

        response = await call_next(request)

        # For auth endpoints that set up session, also set CSRF cookie
        if _is_auth and request.method == "POST":
            # Generate a new CSRF token for the session
            csrf_token = generate_csrf_token()
            is_https = is_secure_request(request)
            response.set_cookie(
                key=CSRF_COOKIE_NAME,
                value=csrf_token,
                httponly=False,  # Must be JS-readable for Double Submit Cookie pattern
                secure=is_https,
                samesite="strict",
            )

        return response


def get_csrf_token(request: Request) -> str | None:
    """Get the CSRF token from the current request's cookies.

    This is useful for server-side rendering where you need to embed
    the token in forms or headers.
    """
    return request.cookies.get(CSRF_COOKIE_NAME)
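
Example (not part of the diff): a client-side sketch of the Double Submit Cookie flow using requests (an assumption; any HTTP client works). The base URL and the /api/threads path are assumptions; the login path comes from the exempt list above.

    import requests

    BASE = "http://localhost:8000"  # assumed dev address

    s = requests.Session()
    # Login is CSRF-exempt; the response sets the csrf_token cookie.
    s.post(f"{BASE}/api/v1/auth/login/local",
           json={"email": "admin@deerflow.dev", "password": "..."})

    # Every state-changing call must echo the cookie back in the header.
    csrf = s.cookies.get("csrf_token")
    s.post(f"{BASE}/api/threads", headers={"X-CSRF-Token": csrf}, json={})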
+180
-25
@@ -1,7 +1,8 @@
"""Centralized accessors for singleton objects stored on ``app.state``.

**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
``None``.

Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
@@ -10,10 +11,15 @@ from __future__ import annotations

from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING

from fastapi import FastAPI, HTTPException, Request

from deerflow.runtime import RunManager, StreamBridge
from deerflow.runtime import RunContext, RunManager

if TYPE_CHECKING:
    from app.gateway.auth.local_provider import LocalAuthProvider
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository


@asynccontextmanager
@@ -26,45 +32,194 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
    yield
    """
    from deerflow.agents.checkpointer.async_provider import make_checkpointer
    from deerflow.config import get_app_config
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
    from deerflow.runtime import make_store, make_stream_bridge
    from deerflow.runtime.events.store import make_run_event_store

    async with AsyncExitStack() as stack:
        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())

        # Initialize persistence engine BEFORE checkpointer so that
        # auto-create-database logic runs first (postgres backend).
        config = get_app_config()
        await init_engine_from_config(config.database)

        app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
        app.state.store = await stack.enter_async_context(make_store())
        app.state.run_manager = RunManager()
        yield

        # Initialize repositories — one get_session_factory() call for all.
        sf = get_session_factory()
        if sf is not None:
            from deerflow.persistence.feedback import FeedbackRepository
            from deerflow.persistence.run import RunRepository
            from deerflow.persistence.thread_meta import ThreadMetaRepository

            app.state.run_store = RunRepository(sf)
            app.state.feedback_repo = FeedbackRepository(sf)
            app.state.thread_meta_repo = ThreadMetaRepository(sf)
        else:
            from deerflow.persistence.thread_meta import MemoryThreadMetaStore
            from deerflow.runtime.runs.store.memory import MemoryRunStore

            app.state.run_store = MemoryRunStore()
            app.state.feedback_repo = None
            app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)

        # Run event store (has its own factory with config-driven backend selection)
        run_events_config = getattr(config, "run_events", None)
        app.state.run_event_store = make_run_event_store(run_events_config)

        # RunManager with store backing for persistence
        app.state.run_manager = RunManager(store=app.state.run_store)

        try:
            yield
        finally:
            await close_engine()
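
A hedged sketch of wiring this lifespan into the application; the real app.py is not shown in this part of the compare view, so treat the surrounding scaffolding as illustrative:

# Hypothetical wiring sketch; only langgraph_runtime itself is from the diff.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.gateway.deps import langgraph_runtime


@asynccontextmanager
async def lifespan(app: FastAPI):
    # langgraph_runtime owns startup order: engine -> checkpointer -> store,
    # then repository selection (SQL if a session factory exists, else memory).
    async with langgraph_runtime(app):
        yield


app = FastAPI(lifespan=lifespan)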
# ---------------------------------------------------------------------------
# Getters – called by routers per-request
# Getters -- called by routers per-request
# ---------------------------------------------------------------------------


def get_stream_bridge(request: Request) -> StreamBridge:
    """Return the global :class:`StreamBridge`, or 503."""
    bridge = getattr(request.app.state, "stream_bridge", None)
    if bridge is None:
        raise HTTPException(status_code=503, detail="Stream bridge not available")
    return bridge
def _require(attr: str, label: str):
    """Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""

    def dep(request: Request):
        val = getattr(request.app.state, attr, None)
        if val is None:
            raise HTTPException(status_code=503, detail=f"{label} not available")
        return val

    dep.__name__ = dep.__qualname__ = f"get_{attr}"
    return dep


def get_run_manager(request: Request) -> RunManager:
    """Return the global :class:`RunManager`, or 503."""
    mgr = getattr(request.app.state, "run_manager", None)
    if mgr is None:
        raise HTTPException(status_code=503, detail="Run manager not available")
    return mgr


def get_checkpointer(request: Request):
    """Return the global checkpointer, or 503."""
    cp = getattr(request.app.state, "checkpointer", None)
    if cp is None:
        raise HTTPException(status_code=503, detail="Checkpointer not available")
    return cp
get_stream_bridge = _require("stream_bridge", "Stream bridge")
get_run_manager = _require("run_manager", "Run manager")
get_checkpointer = _require("checkpointer", "Checkpointer")
get_run_event_store = _require("run_event_store", "Run event store")
get_feedback_repo = _require("feedback_repo", "Feedback")
get_run_store = _require("run_store", "Run store")


def get_store(request: Request):
    """Return the global store (may be ``None`` if not configured)."""
    return getattr(request.app.state, "store", None)


get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
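
For context, a hedged sketch of consuming one of these generated dependencies from a router; the route path and handler are illustrative, not part of the diff:

# Illustrative only; the generated dependencies accept a Request, so they
# work both as plain calls and behind Depends().
from fastapi import APIRouter, Depends

from app.gateway.deps import get_run_manager

router = APIRouter()


@router.get("/api/debug/runs-alive")
async def runs_alive(run_manager=Depends(get_run_manager)) -> dict:
    # _require gave the dependency a stable __name__ ("get_run_manager"),
    # so it shows up cleanly in OpenAPI; missing app.state yields a 503.
    return {"ok": run_manager is not None}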


def get_run_context(request: Request) -> RunContext:
    """Build a :class:`RunContext` from ``app.state`` singletons.

    Returns a *base* context with infrastructure dependencies. Callers that
    need per-run fields (e.g. ``follow_up_to_run_id``) should use
    ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
    to :func:`run_agent`.
    """
    from deerflow.config import get_app_config

    return RunContext(
        checkpointer=get_checkpointer(request),
        store=get_store(request),
        event_store=get_run_event_store(request),
        run_events_config=getattr(get_app_config(), "run_events", None),
        thread_meta_repo=get_thread_meta_repo(request),
    )
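
The per-run override the docstring describes, sketched under the assumption that RunContext is a dataclass (which ``dataclasses.replace`` requires):

# Caller-side sketch; assumes a handler where `request` is in scope and the
# run id is hypothetical.
import dataclasses

ctx = get_run_context(request)
ctx = dataclasses.replace(ctx, follow_up_to_run_id="run-123")
# ctx now carries the per-run field; the infrastructure singletons are shared.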


# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------

# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None


def get_local_provider() -> LocalAuthProvider:
    """Get or create the cached LocalAuthProvider singleton.

    Must be called after ``init_engine_from_config()`` — the shared
    session factory is required to construct the user repository.
    """
    global _cached_local_provider, _cached_repo
    if _cached_repo is None:
        from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
        from deerflow.persistence.engine import get_session_factory

        sf = get_session_factory()
        if sf is None:
            raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
        _cached_repo = SQLiteUserRepository(sf)
    if _cached_local_provider is None:
        from app.gateway.auth.local_provider import LocalAuthProvider

        _cached_local_provider = LocalAuthProvider(repository=_cached_repo)
    return _cached_local_provider


async def get_current_user_from_request(request: Request):
    """Get the current authenticated user from the request cookie.

    Raises HTTPException 401 if not authenticated.
    """
    from app.gateway.auth import decode_token
    from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code

    access_token = request.cookies.get("access_token")
    if not access_token:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
        )

    payload = decode_token(access_token)
    if isinstance(payload, TokenError):
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
        )

    provider = get_local_provider()
    user = await provider.get_user(payload.sub)
    if user is None:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
        )

    # Token version mismatch → password was changed, token is stale
    if user.token_version != payload.ver:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
        )

    return user


async def get_optional_user_from_request(request: Request):
    """Get optional authenticated user from request.

    Returns None if not authenticated.
    """
    try:
        return await get_current_user_from_request(request)
    except HTTPException:
        return None


async def get_current_user(request: Request) -> str | None:
    """Extract user_id from request cookie, or None if not authenticated.

    Thin adapter that returns the string id for callers that only need
    identification (e.g., ``feedback.py``). Full-user callers should use
    ``get_current_user_from_request`` or ``get_optional_user_from_request``.
    """
    user = await get_optional_user_from_request(request)
    return str(user.id) if user else None
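
A hedged sketch of how a route would consume these helpers; the endpoint paths are illustrative, the helpers are from the diff:

# Illustrative routes, not part of the diff.
from fastapi import APIRouter, Request

from app.gateway.deps import get_current_user, get_current_user_from_request

router = APIRouter()


@router.get("/api/example/whoami")
async def whoami(request: Request) -> dict:
    user = await get_current_user_from_request(request)  # 401 when anonymous
    return {"id": str(user.id), "email": user.email}


@router.get("/api/example/maybe-whoami")
async def maybe_whoami(request: Request) -> dict:
    user_id = await get_current_user(request)  # None when anonymous
    return {"id": user_id}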
@@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.

Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.

Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
   and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on — returns metadata filter so each user only sees own threads
"""

import secrets

from langgraph_sdk import Auth

from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider

auth = Auth()

# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})


def _check_csrf(request) -> None:
    """Enforce Double Submit Cookie CSRF check for state-changing requests.

    Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
    proxied directly by nginx have the same CSRF protection.
    """
    method = getattr(request, "method", "") or ""
    if method.upper() not in _CSRF_METHODS:
        return

    cookie_token = request.cookies.get("csrf_token")
    header_token = request.headers.get("x-csrf-token")

    if not cookie_token or not header_token:
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token missing. Include X-CSRF-Token header.",
        )

    if not secrets.compare_digest(cookie_token, header_token):
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token mismatch.",
        )


@auth.authenticate
async def authenticate(request):
    """Validate the session cookie, decode JWT, and check token_version.

    Same validation chain as Gateway's get_current_user_from_request:
    cookie → decode JWT → DB lookup → token_version match
    Also enforces CSRF on state-changing methods.
    """
    # CSRF check before authentication so forged cross-site requests
    # are rejected early, even if the cookie carries a valid JWT.
    _check_csrf(request)

    token = request.cookies.get("access_token")
    if not token:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="Not authenticated",
        )

    payload = decode_token(token)
    if isinstance(payload, TokenError):
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail=f"Token error: {payload.value}",
        )

    user = await get_local_provider().get_user(payload.sub)
    if user is None:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="User not found",
        )
    if user.token_version != payload.ver:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="Token revoked (password changed)",
        )

    return payload.sub


@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
    """Inject owner_id metadata on writes; filter by owner_id on reads.

    Gateway stores thread ownership as ``metadata.owner_id``.
    This handler ensures LangGraph Server enforces the same isolation.
    """
    # On create/update: stamp owner_id into metadata
    metadata = value.setdefault("metadata", {})
    metadata["owner_id"] = ctx.user.identity

    # Return filter dict — LangGraph applies it to search/read/delete
    return {"owner_id": ctx.user.identity}
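
To make the two-sided behavior concrete, a small hedged trace of what add_owner_filter does to a payload; the AuthContext here is faked with SimpleNamespace, not real SDK usage:

# Faked minimal context to trace the handler's effect; identity is hypothetical.
import asyncio
from types import SimpleNamespace

ctx = SimpleNamespace(user=SimpleNamespace(identity="user-42"))
value = {"metadata": {"title": "my thread"}}

flt = asyncio.run(add_owner_filter(ctx, value))
print(value["metadata"])  # {'title': 'my thread', 'owner_id': 'user-42'}  (write stamp)
print(flt)                # {'owner_id': 'user-42'}                        (read filter)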
@@ -7,6 +7,7 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response

from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path

logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
    summary="Get Artifact File",
    description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
    """Get an artifact file by its path.

@@ -0,0 +1,418 @@
"""Authentication endpoints."""

import logging
import os
import time
from ipaddress import ip_address, ip_network

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator

from app.gateway.auth import (
    UserResponse,
    create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/auth", tags=["auth"])


# ── Request/Response Models ──────────────────────────────────────────────


class LoginResponse(BaseModel):
    """Response model for login — token only lives in HttpOnly cookie."""

    expires_in: int  # seconds
    needs_setup: bool = False


# Top common-password blocklist. Drawn from the public SecLists "10k worst
# passwords" set, lowercased + length>=8 only (shorter ones already fail
# the min_length check). Kept tight on purpose: this is the **lower bound**
# defense, not a full HIBP / passlib check, and runs in-process per request.
_COMMON_PASSWORDS: frozenset[str] = frozenset(
    {
        "password",
        "password1",
        "password12",
        "password123",
        "password1234",
        "12345678",
        "123456789",
        "1234567890",
        "qwerty12",
        "qwertyui",
        "qwerty123",
        "abc12345",
        "abcd1234",
        "iloveyou",
        "letmein1",
        "welcome1",
        "welcome123",
        "admin123",
        "administrator",
        "passw0rd",
        "p@ssw0rd",
        "monkey12",
        "trustno1",
        "sunshine",
        "princess",
        "football",
        "baseball",
        "superman",
        "batman123",
        "starwars",
        "dragon123",
        "master123",
        "shadow12",
        "michael1",
        "jennifer",
        "computer",
    }
)


def _password_is_common(password: str) -> bool:
    """Case-insensitive blocklist check.

    Lowercases the input so trivial mutations like ``Password`` /
    ``PASSWORD`` are also rejected. Does not normalize digit substitutions
    (``p@ssw0rd`` is included as a literal entry instead) — keeping the
    rule cheap and predictable.
    """
    return password.lower() in _COMMON_PASSWORDS
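
A quick illustration of the blocklist behavior, using values taken from the set above (runs in this module's context):

# Behavior check; inputs are from the blocklist above, outputs follow from it.
assert _password_is_common("PASSWORD123") is True             # case-folded hit
assert _password_is_common("p@ssw0rd") is True                # literal leet entry
assert _password_is_common("correct horse battery") is False  # not listed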


def _validate_strong_password(value: str) -> str:
    """Pydantic field-validator body shared by Register + ChangePassword.

    Constraint = function, not type-level mixin. The two request models
    have no "is-a" relationship; they only share the password-strength
    rule. Lifting it into a free function lets each model bind it via
    ``@field_validator(field_name)`` without inheritance gymnastics.
    """
    if _password_is_common(value):
        raise ValueError("Password is too common; choose a stronger password.")
    return value


class RegisterRequest(BaseModel):
    """Request model for user registration."""

    email: EmailStr
    password: str = Field(..., min_length=8)

    _strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))


class ChangePasswordRequest(BaseModel):
    """Request model for password change (also handles setup flow)."""

    current_password: str
    new_password: str = Field(..., min_length=8)
    new_email: EmailStr | None = None

    _strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
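
How the shared validator surfaces to callers, sketched as a hedged local check (Pydantic v2 semantics assumed, consistent with the ``model_dump`` calls elsewhere in this diff):

# Validation sketch; exercises the models defined above.
from pydantic import ValidationError

try:
    RegisterRequest(email="a@example.com", password="password123")
except ValidationError as exc:
    # min_length passes (11 chars) but the blocklist validator rejects it.
    print(exc.errors()[0]["msg"])  # "Value error, Password is too common; ..."

ok = RegisterRequest(email="a@example.com", password="x9$Lq-plover-17")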


class MessageResponse(BaseModel):
    """Generic message response."""

    message: str


# ── Helpers ───────────────────────────────────────────────────────────────


def _set_session_cookie(response: Response, token: str, request: Request) -> None:
    """Set the access_token HttpOnly cookie on the response."""
    config = get_auth_config()
    is_https = is_secure_request(request)
    response.set_cookie(
        key="access_token",
        value=token,
        httponly=True,
        secure=is_https,
        samesite="lax",
        max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
    )


# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.

_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300  # 5 minutes

# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}


def _trusted_proxies() -> list:
    """Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.

    Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
    trusted (direct mode). Invalid entries are skipped with a logger warning.
    Read live so env-var overrides take effect immediately and tests can
    ``monkeypatch.setenv`` without poking a module-level cache.
    """
    raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
    if not raw:
        return []
    nets = []
    for entry in raw.split(","):
        entry = entry.strip()
        if not entry:
            continue
        try:
            nets.append(ip_network(entry, strict=False))
        except ValueError:
            logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
    return nets
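
The accepted env-var shapes, shown as a hedged sketch; the CIDR values are examples only:

# Example AUTH_TRUSTED_PROXIES values and what _trusted_proxies() makes of them.
import os

os.environ["AUTH_TRUSTED_PROXIES"] = "10.0.0.0/8, 172.16.0.5, not-an-ip"
print(_trusted_proxies())
# [IPv4Network('10.0.0.0/8'), IPv4Network('172.16.0.5/32')]
# plus a logged warning: ignoring invalid entry 'not-an-ip'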


def _get_client_ip(request: Request) -> str:
    """Extract the real client IP for rate limiting.

    Trust model:

    - The TCP peer (``request.client.host``) is always the baseline. It is
      whatever the kernel reports as the connecting socket — unforgeable
      by the client itself.
    - ``X-Real-IP`` is **only** honored if the TCP peer is in the
      ``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
      CIDR or single IPs). When set, the gateway is assumed to be behind a
      reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
      ``X-Real-IP`` with the original client address.
    - With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
      ignored — closing the bypass where any client could rotate the
      header to dodge per-IP rate limits in dev / direct-gateway mode.

    ``X-Forwarded-For`` is intentionally NOT used because it is naturally
    client-controlled at the *first* hop and the trust chain is harder to
    audit per-request.
    """
    peer_host = request.client.host if request.client else None

    trusted = _trusted_proxies()
    if trusted and peer_host:
        try:
            peer_ip = ip_address(peer_host)
            if any(peer_ip in net for net in trusted):
                real_ip = request.headers.get("x-real-ip", "").strip()
                if real_ip:
                    return real_ip
        except ValueError:
            # peer_host wasn't a parseable IP (e.g. "unknown") — fall through
            pass

    return peer_host or "unknown"


def _check_rate_limit(ip: str) -> None:
    """Raise 429 if the IP is currently locked out."""
    record = _login_attempts.get(ip)
    if record is None:
        return
    fail_count, lock_until = record
    if fail_count >= _MAX_LOGIN_ATTEMPTS:
        if time.time() < lock_until:
            raise HTTPException(
                status_code=429,
                detail="Too many login attempts. Try again later.",
            )
        del _login_attempts[ip]


_MAX_TRACKED_IPS = 10000


def _record_login_failure(ip: str) -> None:
    """Record a failed login attempt for the given IP."""
    # Evict expired lockouts when dict grows too large
    if len(_login_attempts) >= _MAX_TRACKED_IPS:
        now = time.time()
        expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
        for k in expired:
            del _login_attempts[k]
        # If still too large, evict cheapest-to-lose half: below-threshold
        # IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
        if len(_login_attempts) >= _MAX_TRACKED_IPS:
            by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
            for k, _ in by_time[: len(by_time) // 2]:
                del _login_attempts[k]

    record = _login_attempts.get(ip)
    if record is None:
        _login_attempts[ip] = (1, 0.0)
    else:
        new_count = record[0] + 1
        lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
        _login_attempts[ip] = (new_count, lock_until)


def _record_login_success(ip: str) -> None:
    """Clear failure counter for the given IP on successful login."""
    _login_attempts.pop(ip, None)
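
A hedged trace of the lockout lifecycle using the helpers above (single process, one IP, times compressed):

# Lockout lifecycle sketch; exercises the in-process counters directly.
from fastapi import HTTPException

ip = "203.0.113.7"  # documentation-range example address

for _ in range(5):
    _record_login_failure(ip)  # 5th failure sets lock_until = now + 300s

try:
    _check_rate_limit(ip)      # raises 429 while the lock is active
except HTTPException as exc:
    print(exc.status_code)     # 429

_record_login_success(ip)      # successful login clears the counter
_check_rate_limit(ip)          # no-op again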


# ── Endpoints ─────────────────────────────────────────────────────────────


@router.post("/login/local", response_model=LoginResponse)
async def login_local(
    request: Request,
    response: Response,
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Local email/password login."""
    client_ip = _get_client_ip(request)
    _check_rate_limit(client_ip)

    user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})

    if user is None:
        _record_login_failure(client_ip)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
        )

    _record_login_success(client_ip)
    token = create_access_token(str(user.id), token_version=user.token_version)
    _set_session_cookie(response, token, request)

    return LoginResponse(
        expires_in=get_auth_config().token_expiry_days * 24 * 3600,
        needs_setup=user.needs_setup,
    )


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
    """Register a new user account (always 'user' role).

    Admin is auto-created on first boot. This endpoint creates regular users.
    Auto-login by setting the session cookie.
    """
    try:
        user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
        )

    token = create_access_token(str(user.id), token_version=user.token_version)
    _set_session_cookie(response, token, request)

    return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)


@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
    """Logout current user by clearing the cookie."""
    response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
    return MessageResponse(message="Successfully logged out")


@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
    """Change password for the currently authenticated user.

    Also handles the first-boot setup flow:
    - If new_email is provided, updates email (checks uniqueness)
    - If user.needs_setup is True and new_email is given, clears needs_setup
    - Always increments token_version to invalidate old sessions
    - Re-issues session cookie with new token_version
    """
    from app.gateway.auth.password import hash_password_async, verify_password_async

    user = await get_current_user_from_request(request)

    if user.password_hash is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())

    if not await verify_password_async(body.current_password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())

    provider = get_local_provider()

    # Update email if provided
    if body.new_email is not None:
        existing = await provider.get_user_by_email(body.new_email)
        if existing and str(existing.id) != str(user.id):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
        user.email = body.new_email

    # Update password + bump version
    user.password_hash = await hash_password_async(body.new_password)
    user.token_version += 1

    # Clear setup flag if this is the setup flow
    if user.needs_setup and body.new_email is not None:
        user.needs_setup = False

    await provider.update_user(user)

    # Re-issue cookie with new token_version
    token = create_access_token(str(user.id), token_version=user.token_version)
    _set_session_cookie(response, token, request)

    return MessageResponse(message="Password changed successfully")


@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
    """Get current authenticated user info."""
    user = await get_current_user_from_request(request)
    return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)


@router.get("/setup-status")
async def setup_status():
    """Check if admin account exists. Always False after first boot."""
    user_count = await get_local_provider().count_users()
    return {"needs_setup": user_count == 0}


# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────


@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
    """Initiate OAuth login flow.

    Redirects to the OAuth provider's authorization URL.
    Currently a placeholder - requires OAuth provider implementation.
    """
    if provider not in ["github", "google"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported OAuth provider: {provider}",
        )

    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="OAuth login not yet implemented",
    )


@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
    """OAuth callback endpoint.

    Handles the OAuth provider's callback after user authorization.
    Currently a placeholder.
    """
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="OAuth callback not yet implemented",
    )
@@ -0,0 +1,132 @@
"""Feedback endpoints — create, list, stats, delete.

Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""

from __future__ import annotations

import logging
from typing import Any

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field

from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------


class FeedbackCreateRequest(BaseModel):
    rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
    comment: str | None = Field(default=None, description="Optional text feedback")
    message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")


class FeedbackResponse(BaseModel):
    feedback_id: str
    run_id: str
    thread_id: str
    owner_id: str | None = None
    message_id: str | None = None
    rating: int
    comment: str | None = None
    created_at: str = ""


class FeedbackStatsResponse(BaseModel):
    run_id: str
    total: int = 0
    positive: int = 0
    negative: int = 0


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
    thread_id: str,
    run_id: str,
    body: FeedbackCreateRequest,
    request: Request,
) -> dict[str, Any]:
    """Submit feedback (thumbs-up/down) for a run."""
    if body.rating not in (1, -1):
        raise HTTPException(status_code=400, detail="rating must be +1 or -1")

    user_id = await get_current_user(request)

    # Validate run exists and belongs to thread
    run_store = get_run_store(request)
    run = await run_store.get(run_id)
    if run is None:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    if run.get("thread_id") != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")

    feedback_repo = get_feedback_repo(request)
    return await feedback_repo.create(
        run_id=run_id,
        thread_id=thread_id,
        rating=body.rating,
        owner_id=user_id,
        message_id=body.message_id,
        comment=body.comment,
    )


@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
    thread_id: str,
    run_id: str,
    request: Request,
) -> list[dict[str, Any]]:
    """List all feedback for a run."""
    feedback_repo = get_feedback_repo(request)
    return await feedback_repo.list_by_run(thread_id, run_id)


@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
    thread_id: str,
    run_id: str,
    request: Request,
) -> dict[str, Any]:
    """Get aggregated feedback stats (positive/negative counts) for a run."""
    feedback_repo = get_feedback_repo(request)
    return await feedback_repo.aggregate_by_run(thread_id, run_id)


@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
    thread_id: str,
    run_id: str,
    feedback_id: str,
    request: Request,
) -> dict[str, bool]:
    """Delete a feedback record."""
    feedback_repo = get_feedback_repo(request)
    # Verify feedback belongs to the specified thread/run before deleting
    existing = await feedback_repo.get(feedback_id)
    if existing is None:
        raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
    if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
        raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
    deleted = await feedback_repo.delete(feedback_id)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
    return {"success": True}
@@ -1,10 +1,11 @@
import json
import logging

from fastapi import APIRouter
from fastapi import APIRouter, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from app.gateway.authz import require_permission
from deerflow.models import create_chat_model

logger = logging.getLogger(__name__)
@@ -98,12 +99,13 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
    summary="Generate Follow-up Questions",
    description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
    if not request.messages:
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
    if not body.messages:
        return SuggestionsResponse(suggestions=[])

    n = request.n
    conversation = _format_conversation(request.messages)
    n = body.n
    conversation = _format_conversation(body.messages)
    if not conversation:
        return SuggestionsResponse(suggestions=[])

@@ -120,7 +122,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
    user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"

    try:
        model = create_chat_model(name=request.model_name, thinking_enabled=False)
        model = create_chat_model(name=body.model_name, thinking_enabled=False)
        response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
        raw = _extract_response_text(response.content)
        suggestions = _parse_json_string_list(raw) or []
@@ -19,7 +19,8 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field

from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values

@@ -53,6 +54,7 @@ class RunCreateRequest(BaseModel):
    after_seconds: float | None = Field(default=None, description="Delayed execution")
    if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
    feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
    follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")


class RunResponse(BaseModel):
@@ -92,6 +94,7 @@ def _record_to_response(record: RunRecord) -> RunResponse:


@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
    """Create a background run (returns immediately)."""
    record = await start_run(body, thread_id, request)
@@ -99,6 +102,7 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -


@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    """Create a run and stream events via SSE.

@@ -126,6 +130,7 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -


@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
    """Create a run and block until it completes, returning the final state."""
    record = await start_run(body, thread_id, request)
@@ -151,6 +156,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->


@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
    """List all runs for a thread."""
    run_mgr = get_run_manager(request)
@@ -159,6 +165,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:


@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Get details of a specific run."""
    run_mgr = get_run_manager(request)
@@ -169,6 +176,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:


@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
    thread_id: str,
    run_id: str,
@@ -206,6 +214,7 @@ async def cancel_run(


@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Join an existing run's SSE stream."""
    bridge = get_stream_bridge(request)
@@ -226,6 +235,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe


@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
@@ -265,3 +275,54 @@ async def stream_existing_run(
            "X-Accel-Buffering": "no",
        },
    )


# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------


@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
    thread_id: str,
    request: Request,
    limit: int = Query(default=50, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs)."""
    event_store = get_run_event_store(request)
    return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)


@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
    """Return displayable messages for a specific run."""
    event_store = get_run_event_store(request)
    return await event_store.list_messages_by_run(thread_id, run_id)


@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit)."""
    event_store = get_run_event_store(request)
    types = event_types.split(",") if event_types else None
    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)


@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Thread-level token usage aggregation."""
    run_store = get_run_store(request)
    agg = await run_store.aggregate_tokens_by_thread(thread_id)
    return {"thread_id": thread_id, **agg}
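
A hedged client sketch of the cursor-style pagination on the messages endpoint; it assumes each message dict carries a ``seq`` field, which the before_seq/after_seq parameters imply but this diff does not show:

# Hypothetical pagination walk (httpx assumed; URL and ids are examples).
import httpx

client = httpx.Client(base_url="http://localhost:8000")  # assumed dev URL
thread = "thread-abc"  # hypothetical id

page = client.get(f"/api/threads/{thread}/messages", params={"limit": 50}).json()
while page:
    oldest = page[-1]["seq"]  # assumed field; walk backwards from it
    page = client.get(
        f"/api/threads/{thread}/messages",
        params={"limit": 50, "before_seq": oldest},
    ).json()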
@@ -18,23 +18,34 @@ import uuid
from typing import Any

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from app.gateway.deps import get_checkpointer, get_store
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values

# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------

THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])


# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.owner_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})


def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
    """Return ``metadata`` with server-controlled keys removed."""
    if not metadata:
        return metadata or {}
    return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
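
The stripping behavior, illustrated with made-up inputs (the function itself is from the diff):

# Reserved keys are dropped; everything else passes through untouched.
payload = {"owner_id": "attacker-chosen", "user_id": "x", "title": "Q3 report"}
print(_strip_reserved_metadata(payload))  # {'title': 'Q3 report'}
print(_strip_reserved_metadata(None))     # {}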
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
@@ -63,8 +74,11 @@ class ThreadCreateRequest(BaseModel):
    """Request body for creating a thread."""

    thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
    assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")

    _strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))


class ThreadSearchRequest(BaseModel):
    """Request body for searching threads."""
@@ -93,6 +107,8 @@ class ThreadPatchRequest(BaseModel):

    metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")

    _strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))


class ThreadStateUpdateRequest(BaseModel):
    """Request body for updating thread state (human-in-the-loop resume)."""
@@ -135,61 +151,16 @@ def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDel
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except FileNotFoundError:
        # Not critical — thread data may not exist on disk
        logger.debug("No local thread data to delete for %s", thread_id)
        logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
        return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
    except Exception as exc:
        logger.exception("Failed to delete thread data for %s", thread_id)
        logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc

    logger.info("Deleted local thread data for %s", thread_id)
    logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
    return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")


async def _store_get(store, thread_id: str) -> dict | None:
    """Fetch a thread record from the Store; returns ``None`` if absent."""
    item = await store.aget(THREADS_NS, thread_id)
    return item.value if item is not None else None


async def _store_put(store, record: dict) -> None:
    """Write a thread record to the Store."""
    await store.aput(THREADS_NS, record["thread_id"], record)


async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
    """Create or refresh a thread record in the Store.

    On creation the record is written with ``status="idle"``. On update only
    ``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
    that existing fields are preserved.

    ``values`` carries the agent-state snapshot exposed to the frontend
    (currently just ``{"title": "..."}``).
    """
    now = time.time()
    existing = await _store_get(store, thread_id)
    if existing is None:
        await _store_put(
            store,
            {
                "thread_id": thread_id,
                "status": "idle",
                "created_at": now,
                "updated_at": now,
                "metadata": metadata or {},
                "values": values or {},
            },
        )
    else:
        val = dict(existing)
        val["updated_at"] = now
        if metadata:
            val.setdefault("metadata", {}).update(metadata)
        if values:
            val.setdefault("values", {}).update(values)
        await _store_put(store, val)


def _derive_thread_status(checkpoint_tuple) -> str:
    """Derive thread status from checkpoint metadata."""
    if checkpoint_tuple is None:
@@ -215,23 +186,19 @@


@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
    """Delete local persisted filesystem data for a thread.

    Cleans DeerFlow-managed thread directories, removes checkpoint data,
    and removes the thread record from the Store.
    and removes the thread_meta row from the configured ThreadMetaStore
    (sqlite or memory).
    """
    from app.gateway.deps import get_thread_meta_repo

    # Clean local filesystem
    response = _delete_thread_data(thread_id)

    # Remove from Store (best-effort)
    store = get_store(request)
    if store is not None:
        try:
            await store.adelete(THREADS_NS, thread_id)
        except Exception:
            logger.debug("Could not delete store record for thread %s (not critical)", thread_id)

    # Remove checkpoints (best-effort)
    checkpointer = getattr(request.app.state, "checkpointer", None)
    if checkpointer is not None:
@@ -239,7 +206,15 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
        if hasattr(checkpointer, "adelete_thread"):
            await checkpointer.adelete_thread(thread_id)
    except Exception:
        logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
        logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))

    # Remove thread_meta row (best-effort) — required for sqlite backend
    # so the deleted thread no longer appears in /threads/search.
    try:
        thread_meta_repo = get_thread_meta_repo(request)
        await thread_meta_repo.delete(thread_id)
    except Exception:
        logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))

    return response
@@ -248,43 +223,40 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
    """Create a new thread.

    The thread record is written to the Store (for fast listing) and an
    empty checkpoint is written to the checkpointer (for state reads).
    Writes a thread_meta record (so the thread appears in /threads/search)
    and an empty checkpoint (so state endpoints work immediately).
    Idempotent: returns the existing record when ``thread_id`` already exists.
    """
    store = get_store(request)
    from app.gateway.deps import get_thread_meta_repo

    checkpointer = get_checkpointer(request)
    thread_meta_repo = get_thread_meta_repo(request)
    thread_id = body.thread_id or str(uuid.uuid4())
    now = time.time()
    # ``body.metadata`` is already stripped of server-reserved keys by
    # ``ThreadCreateRequest._strip_reserved`` — see the model definition.

    # Idempotency: return existing record from Store when already present
    if store is not None:
        existing_record = await _store_get(store, thread_id)
        if existing_record is not None:
            return ThreadResponse(
                thread_id=thread_id,
                status=existing_record.get("status", "idle"),
                created_at=str(existing_record.get("created_at", "")),
                updated_at=str(existing_record.get("updated_at", "")),
                metadata=existing_record.get("metadata", {}),
            )
    # Idempotency: return existing record when already present
    existing_record = await thread_meta_repo.get(thread_id)
    if existing_record is not None:
        return ThreadResponse(
            thread_id=thread_id,
            status=existing_record.get("status", "idle"),
            created_at=str(existing_record.get("created_at", "")),
            updated_at=str(existing_record.get("updated_at", "")),
            metadata=existing_record.get("metadata", {}),
        )

    # Write thread record to Store
    if store is not None:
        try:
            await _store_put(
                store,
                {
                    "thread_id": thread_id,
                    "status": "idle",
                    "created_at": now,
                    "updated_at": now,
                    "metadata": body.metadata,
                },
            )
        except Exception:
            logger.exception("Failed to write thread %s to store", thread_id)
            raise HTTPException(status_code=500, detail="Failed to create thread")
    # Write thread_meta so the thread appears in /threads/search immediately
    try:
        await thread_meta_repo.create(
            thread_id,
            assistant_id=getattr(body, "assistant_id", None),
            metadata=body.metadata,
        )
    except Exception:
        logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to create thread")

    # Write an empty checkpoint so state endpoints work immediately
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -301,10 +273,10 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
        }
        await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
    except Exception:
        logger.exception("Failed to create checkpoint for thread %s", thread_id)
        logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to create thread")

    logger.info("Thread created: %s", thread_id)
    logger.info("Thread created: %s", sanitize_log_param(thread_id))
    return ThreadResponse(
        thread_id=thread_id,
        status="idle",
@@ -318,166 +290,91 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
    """Search and list threads.

    Two-phase approach:

    **Phase 1 — Store (fast path, O(threads))**: returns threads that were
    created or run through this Gateway. Store records are tiny metadata
    dicts so fetching all of them at once is cheap.

    **Phase 2 — Checkpointer supplement (lazy migration)**: threads that
    were created directly by LangGraph Server (and therefore absent from the
    Store) are discovered here by iterating the shared checkpointer. Any
    newly found thread is immediately written to the Store so that the next
    search skips Phase 2 for that thread — the Store converges to a full
    index over time without a one-shot migration job.
    Delegates to the configured ThreadMetaStore implementation
    (SQL-backed for sqlite/postgres, Store-backed for memory mode).
    """
    store = get_store(request)
    checkpointer = get_checkpointer(request)
    from app.gateway.deps import get_thread_meta_repo

    # -----------------------------------------------------------------------
    # Phase 1: Store
    # -----------------------------------------------------------------------
    merged: dict[str, ThreadResponse] = {}

    if store is not None:
        try:
            items = await store.asearch(THREADS_NS, limit=10_000)
        except Exception:
            logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
            items = []

        for item in items:
            val = item.value
            merged[val["thread_id"]] = ThreadResponse(
                thread_id=val["thread_id"],
                status=val.get("status", "idle"),
                created_at=str(val.get("created_at", "")),
                updated_at=str(val.get("updated_at", "")),
                metadata=val.get("metadata", {}),
                values=val.get("values", {}),
            )

    # -----------------------------------------------------------------------
    # Phase 2: Checkpointer supplement
    # Discovers threads not yet in the Store (e.g. created by LangGraph
    # Server) and lazily migrates them so future searches skip this phase.
    # -----------------------------------------------------------------------
    try:
        async for checkpoint_tuple in checkpointer.alist(None):
            cfg = getattr(checkpoint_tuple, "config", {})
            thread_id = cfg.get("configurable", {}).get("thread_id")
            if not thread_id or thread_id in merged:
                continue

            # Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
            if cfg.get("configurable", {}).get("checkpoint_ns", ""):
                continue

            ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
            # Strip LangGraph internal keys from the user-visible metadata dict
            user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}

            # Extract state values (title) from the checkpoint's channel_values
            checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
            channel_values = checkpoint_data.get("channel_values", {})
            ckpt_values = {}
            if title := channel_values.get("title"):
                ckpt_values["title"] = title

            thread_resp = ThreadResponse(
                thread_id=thread_id,
                status=_derive_thread_status(checkpoint_tuple),
                created_at=str(ckpt_meta.get("created_at", "")),
                updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
                metadata=user_meta,
                values=ckpt_values,
            )
            merged[thread_id] = thread_resp

            # Lazy migration — write to Store so the next search finds it there
            if store is not None:
                try:
                    await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
                except Exception:
                    logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
    except Exception:
        logger.exception("Checkpointer scan failed during thread search")
        # Don't raise — return whatever was collected from Store + partial scan

    # -----------------------------------------------------------------------
    # Phase 3: Filter → sort → paginate
    # -----------------------------------------------------------------------
    results = list(merged.values())

    if body.metadata:
        results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]

    if body.status:
        results = [r for r in results if r.status == body.status]

    results.sort(key=lambda r: r.updated_at, reverse=True)
    return results[body.offset : body.offset + body.limit]
    repo = get_thread_meta_repo(request)
    rows = await repo.search(
        metadata=body.metadata or None,
        status=body.status,
        limit=body.limit,
        offset=body.offset,
    )
    return [
        ThreadResponse(
            thread_id=r["thread_id"],
            status=r.get("status", "idle"),
            created_at=r.get("created_at", ""),
            updated_at=r.get("updated_at", ""),
            metadata=r.get("metadata", {}),
            values={"title": r["display_name"]} if r.get("display_name") else {},
            interrupts={},
        )
        for r in rows
    ]

@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
    """Merge metadata into a thread record."""
    store = get_store(request)
    if store is None:
        raise HTTPException(status_code=503, detail="Store not available")
    from app.gateway.deps import get_thread_meta_repo

    record = await _store_get(store, thread_id)
    thread_meta_repo = get_thread_meta_repo(request)
    record = await thread_meta_repo.get(thread_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    now = time.time()
    updated = dict(record)
    updated.setdefault("metadata", {}).update(body.metadata)
    updated["updated_at"] = now

    # ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
    try:
        await _store_put(store, updated)
        await thread_meta_repo.update_metadata(thread_id, body.metadata)
    except Exception:
        logger.exception("Failed to patch thread %s", thread_id)
        logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to update thread")

    # Re-read to get the merged metadata + refreshed updated_at
    record = await thread_meta_repo.get(thread_id) or record
    return ThreadResponse(
        thread_id=thread_id,
        status=updated.get("status", "idle"),
        created_at=str(updated.get("created_at", "")),
        updated_at=str(now),
        metadata=updated.get("metadata", {}),
        status=record.get("status", "idle"),
        created_at=str(record.get("created_at", "")),
        updated_at=str(record.get("updated_at", "")),
        metadata=record.get("metadata", {}),
    )

@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
    """Get thread info.

    Reads metadata from the Store and derives the accurate execution
    status from the checkpointer. Falls back to the checkpointer alone
    for threads that pre-date Store adoption (backward compat).
    Reads metadata from the ThreadMetaStore and derives the accurate
    execution status from the checkpointer. Falls back to the checkpointer
    alone for threads that pre-date ThreadMetaStore adoption (backward compat).
    """
    store = get_store(request)
    from app.gateway.deps import get_thread_meta_repo

    thread_meta_repo = get_thread_meta_repo(request)
    checkpointer = get_checkpointer(request)

    record: dict | None = None
    if store is not None:
        record = await _store_get(store, thread_id)
    record: dict | None = await thread_meta_repo.get(thread_id)

    # Derive accurate status from the checkpointer
    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    try:
        checkpoint_tuple = await checkpointer.aget_tuple(config)
    except Exception:
        logger.exception("Failed to get checkpoint for thread %s", thread_id)
        logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread")

    if record is None and checkpoint_tuple is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    # If the thread exists in the checkpointer but not the store (e.g. legacy
    # data), synthesize a minimal store record from the checkpoint metadata.
    # If the thread exists in the checkpointer but not in thread_meta (e.g.
    # legacy data created before thread_meta adoption), synthesize a minimal
    # record from the checkpoint metadata.
    if record is None and checkpoint_tuple is not None:
        ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
        record = {
@@ -506,6 +403,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:

@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
    """Get the latest state snapshot for a thread.

@@ -518,7 +416,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
    try:
        checkpoint_tuple = await checkpointer.aget_tuple(config)
    except Exception:
        logger.exception("Failed to get state for thread %s", thread_id)
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")

    if checkpoint_tuple is None:
@@ -555,15 +453,19 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo

@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
    """Update thread state (e.g. for human-in-the-loop resume or title rename).

    Writes a new checkpoint that merges *body.values* into the latest
    channel values, then syncs any updated ``title`` field back to the Store
    so that ``/threads/search`` reflects the change immediately.
    channel values, then syncs any updated ``title`` field through the
    ThreadMetaStore abstraction so that ``/threads/search`` reflects the
    change immediately in both sqlite and memory backends.
    """
    from app.gateway.deps import get_thread_meta_repo

    checkpointer = get_checkpointer(request)
    store = get_store(request)
    thread_meta_repo = get_thread_meta_repo(request)

    # checkpoint_ns must be present in the config for aput — default to ""
    # (the root graph namespace). checkpoint_id is optional; omitting it
@@ -580,7 +482,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
    try:
        checkpoint_tuple = await checkpointer.aget_tuple(read_config)
    except Exception:
        logger.exception("Failed to get state for thread %s", thread_id)
        logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread state")

    if checkpoint_tuple is None:
@@ -614,19 +516,22 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
    try:
        new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
    except Exception:
        logger.exception("Failed to update state for thread %s", thread_id)
        logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to update thread state")

    new_checkpoint_id: str | None = None
    if isinstance(new_config, dict):
        new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")

    # Sync title changes to the Store so /threads/search reflects them immediately.
    if store is not None and body.values and "title" in body.values:
        try:
            await _store_upsert(store, thread_id, values={"title": body.values["title"]})
        except Exception:
            logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
    # Sync title changes through the ThreadMetaStore abstraction so /threads/search
    # reflects them immediately in both sqlite and memory backends.
    if body.values and "title" in body.values:
        new_title = body.values["title"]
        if new_title:  # Skip empty strings and None
            try:
                await thread_meta_repo.update_display_name(thread_id, new_title)
            except Exception:
                logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))

    return ThreadStateResponse(
        values=serialize_channel_values(channel_values),
@@ -638,8 +543,16 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re

@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
    """Get checkpoint history for a thread."""
    """Get checkpoint history for a thread.

    Messages are read from the checkpointer's channel values (the
    authoritative source) and serialized via
    :func:`~deerflow.runtime.serialization.serialize_channel_values`.
    Only the latest (first) checkpoint carries the ``messages`` key to
    avoid duplicating them across every entry.
    """
    checkpointer = get_checkpointer(request)

    config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
@@ -647,6 +560,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
        config["configurable"]["checkpoint_id"] = body.before

    entries: list[HistoryEntry] = []
    is_latest_checkpoint = True
    try:
        async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
            ckpt_config = getattr(checkpoint_tuple, "config", {})
@@ -661,22 +575,42 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request

            channel_values = checkpoint.get("channel_values", {})

            # Build values from checkpoint channel_values
            values: dict[str, Any] = {}
            if title := channel_values.get("title"):
                values["title"] = title
            if thread_data := channel_values.get("thread_data"):
                values["thread_data"] = thread_data

            # Attach messages from checkpointer only for the latest checkpoint
            if is_latest_checkpoint:
                messages = channel_values.get("messages")
                if messages:
                    values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
                is_latest_checkpoint = False

            # Derive next tasks
            tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
            next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]

            # Strip LangGraph internal keys from metadata
            user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
            # Keep step for ordering context
            if "step" in metadata:
                user_meta["step"] = metadata["step"]

            entries.append(
                HistoryEntry(
                    checkpoint_id=checkpoint_id,
                    parent_checkpoint_id=parent_id,
                    metadata=metadata,
                    values=serialize_channel_values(channel_values),
                    metadata=user_meta,
                    values=values,
                    created_at=str(metadata.get("created_at", "")),
                    next=next_tasks,
                )
            )
    except Exception:
        logger.exception("Failed to get history for thread %s", thread_id)
        logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
        raise HTTPException(status_code=500, detail="Failed to get thread history")

    return entries

@@ -4,9 +4,10 @@ import logging
import os
import stat

from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from pydantic import BaseModel

from app.gateway.authz import require_permission
from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import (
@@ -54,8 +55,10 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:


@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upload_files(
    thread_id: str,
    request: Request,
    files: list[UploadFile] = File(...),
) -> UploadResponse:
    """Upload multiple files to a thread's uploads directory."""
@@ -133,7 +136,8 @@ async def upload_files(


@router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict:
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
    """List all files in a thread's uploads directory."""
    try:
        uploads_dir = get_uploads_dir(thread_id)
@@ -151,7 +155,8 @@ async def list_uploaded_files(thread_id: str) -> dict:


@router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
    """Delete a file from a thread's uploads directory."""
    try:
        uploads_dir = get_uploads_dir(thread_id)

@@ -8,16 +8,17 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations

import asyncio
import dataclasses
import json
import logging
import re
import time
from typing import Any

from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage

from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import (
    END_SENTINEL,
    HEARTBEAT_SENTINEL,
@@ -171,71 +172,6 @@ def build_run_config(
# ---------------------------------------------------------------------------


async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
    """Create or refresh the thread record in the Store.

    Called from :func:`start_run` so that threads created via the stateless
    ``/runs/stream`` endpoint (which never calls ``POST /threads``) still
    appear in ``/threads/search`` results.
    """
    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_upsert

    try:
        await _store_upsert(store, thread_id, metadata=metadata)
    except Exception:
        logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)


async def _sync_thread_title_after_run(
    run_task: asyncio.Task,
    thread_id: str,
    checkpointer: Any,
    store: Any,
) -> None:
    """Wait for *run_task* to finish, then persist the generated title to the Store.

    TitleMiddleware writes the generated title to the LangGraph agent state
    (checkpointer) but the Gateway's Store record is not updated automatically.
    This coroutine closes that gap by reading the final checkpoint after the
    run completes and syncing ``values.title`` into the Store record so that
    subsequent ``/threads/search`` responses include the correct title.

    Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
    logged at DEBUG level and never propagate.
    """
    # Wait for the background run task to complete (any outcome).
    # asyncio.wait does not propagate task exceptions — it just returns
    # when the task is done, cancelled, or failed.
    await asyncio.wait({run_task})

    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_get, _store_put

    try:
        ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
        ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
        if ckpt_tuple is None:
            return

        channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
        title = channel_values.get("title")
        if not title:
            return

        existing = await _store_get(store, thread_id)
        if existing is None:
            return

        updated = dict(existing)
        updated.setdefault("values", {})["title"] = title
        updated["updated_at"] = time.time()
        await _store_put(store, updated)
        logger.debug("Synced title %r for thread %s", title, thread_id)
    except Exception:
        logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)

async def start_run(
    body: Any,
    thread_id: str,
@@ -255,11 +191,25 @@ async def start_run(
    """
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    checkpointer = get_checkpointer(request)
    store = get_store(request)
    run_ctx = get_run_context(request)

    disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_

    # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
    follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
    if follow_up_to_run_id is None:
        run_store = get_run_store(request)
        try:
            recent_runs = await run_store.list_by_thread(thread_id, limit=1)
            if recent_runs and recent_runs[0].get("status") == "success":
                follow_up_to_run_id = recent_runs[0]["run_id"]
        except Exception:
            pass  # Don't block run creation

    # Enrich base context with per-run field
    if follow_up_to_run_id:
        run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)

    try:
        record = await run_mgr.create_or_reject(
            thread_id,
@@ -268,17 +218,28 @@ async def start_run(
            metadata=body.metadata or {},
            kwargs={"input": body.input, "config": body.config},
            multitask_strategy=body.multitask_strategy,
            follow_up_to_run_id=follow_up_to_run_id,
        )
    except ConflictError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except UnsupportedStrategyError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc

    # Ensure the thread is visible in /threads/search, even for threads that
    # were never explicitly created via POST /threads (e.g. stateless runs).
    store = get_store(request)
    if store is not None:
        await _upsert_thread_in_store(store, thread_id, body.metadata)
    # Upsert thread metadata so the thread appears in /threads/search,
    # even for threads that were never explicitly created via POST /threads
    # (e.g. stateless runs).
    try:
        existing = await run_ctx.thread_meta_repo.get(thread_id)
        if existing is None:
            await run_ctx.thread_meta_repo.create(
                thread_id,
                assistant_id=body.assistant_id,
                metadata=body.metadata,
            )
        else:
            await run_ctx.thread_meta_repo.update_status(thread_id, "running")
    except Exception:
        logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))

    agent_factory = resolve_agent_factory(body.assistant_id)
    graph_input = normalize_input(body.input)
@@ -311,8 +272,7 @@ async def start_run(
        bridge,
        run_mgr,
        record,
        checkpointer=checkpointer,
        store=store,
        ctx=run_ctx,
        agent_factory=agent_factory,
        graph_input=graph_input,
        config=config,
@@ -324,11 +284,9 @@ async def start_run(
    )
    record.task = task

    # After the run completes, sync the title generated by TitleMiddleware from
    # the checkpointer into the Store record so that /threads/search returns the
    # correct title instead of an empty values dict.
    if store is not None:
        asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
    # Title sync is handled by worker.py's finally block which reads the
    # title from the checkpoint and calls thread_meta_repo.update_display_name
    # after the run completes.

    return record

@@ -0,0 +1,6 @@
"""Shared utility helpers for the Gateway layer."""


def sanitize_log_param(value: str) -> str:
    """Strip control characters to prevent log injection."""
    return value.replace("\n", "").replace("\r", "").replace("\x00", "")
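
A quick illustration of what this buys (an editor's sketch; the function is exactly as defined above). Without the stripping, attacker-controlled IDs containing CR/LF could start forged log lines:

```python
# Embedded CR/LF is flattened, so one log call stays one log line.
assert sanitize_log_param("t-123\r\nERROR forged entry") == "t-123ERROR forged entry"
```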

@@ -0,0 +1,77 @@
# Docker Test Gap (Section 7.4)

This file documents the only **un-executed** test cases from
`backend/docs/AUTH_TEST_PLAN.md` after the full release validation pass.

## Why this gap exists

The release validation environment (sg_dev: `10.251.229.92`) **does not have
a Docker daemon installed**. The TC-DOCKER cases are container-runtime
behavior tests that need an actual Docker engine to spin up
`docker/docker-compose.yaml` services.

```bash
$ ssh sg_dev "which docker; docker --version"
# (empty)
# bash: docker: command not found
```

All other test plan sections were executed against either:

- The local dev box (Mac, all services running locally), or
- The deployed sg_dev instance (gateway + frontend + nginx via SSH tunnel)

## Cases not executed

| Case | Title | What it covers | Why not run |
|---|---|---|---|
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |

## Coverage already provided by non-Docker tests

The **auth-relevant** behavior in each Docker case is already exercised by
the test cases that ran on sg_dev or local:

| Docker case | Auth behavior covered by |
|---|---|
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through the SDK, not HTTP |
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (gateway-mode container) | Section 7.2 covered by TC-GW-01..05 + Section 2 (gateway-mode auth flow on sg_dev) — same Gateway code, the container is just a packaging change |

## Reproduction steps when Docker becomes available

Anyone with `docker` + `docker compose` installed can reproduce the gap by
running the test plan section verbatim. Pre-flight:

```bash
# Required on the host
docker --version # >=24.x
docker compose version # plugin >=2.x

# Required env var (otherwise sessions reset on every container restart)
echo "AUTH_JWT_SECRET=$(python3 -c 'import secrets; print(secrets.token_urlsafe(32))')" \
  >> .env

# Optional: pin DEER_FLOW_HOME to a stable host path
echo "DEER_FLOW_HOME=$HOME/deer-flow-data" >> .env
```

Then run TC-DOCKER-01..06 from the test plan as written.

## Decision log

- **Not blocking the release.** The auth-relevant behavior in every Docker
  case has an already-validated equivalent on bare metal. The gap is purely
  about *container packaging* details (bind mounts, multi-worker, log
  collection), not about whether the auth code paths work.
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
  the post-simplify reality (credentials file → 0600 file, no log leak).
  The old "grep 'Password:' in docker logs" expectation would have failed
  silently and given a false sense of coverage.
File diff suppressed because it is too large
@@ -0,0 +1,129 @@
|
||||
# Authentication Upgrade Guide
|
||||
|
||||
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
|
||||
|
||||
## 核心概念
|
||||
|
||||
认证模块采用**始终强制**策略:
|
||||
|
||||
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
|
||||
- 认证从一开始就是强制的,无竞争窗口
|
||||
- 历史对话(升级前创建的 thread)自动迁移到 admin 名下
|
||||
|
||||
## 升级步骤
|
||||
|
||||
### 1. 更新代码
|
||||
|
||||
```bash
|
||||
git pull origin main
|
||||
cd backend && make install
|
||||
```
|
||||
|
||||
### 2. 首次启动
|
||||
|
||||
```bash
|
||||
make dev
|
||||
```
|
||||
|
||||
控制台会输出:
|
||||
|
||||
```
|
||||
============================================================
|
||||
Admin account created on first boot
|
||||
Email: admin@deerflow.dev
|
||||
Password: aB3xK9mN_pQ7rT2w
|
||||
Change it after login: Settings → Account
|
||||
============================================================
|
||||
```
|
||||
|
||||
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
|
||||
|
||||
### 3. 登录
|
||||
|
||||
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
|
||||
|
||||
### 4. 修改密码
|
||||
|
||||
登录后进入 Settings → Account → Change Password。
|
||||
|
||||
### 5. 添加用户(可选)
|
||||
|
||||
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
|
||||
|
||||
## 安全机制
|
||||
|
||||
| 机制 | 说明 |
|
||||
|------|------|
|
||||
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
|
||||
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
|
||||
| bcrypt 密码哈希 | 密码不以明文存储 |
|
||||
| 多租户隔离 | 用户只能访问自己的 thread |
|
||||
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
|
||||
|
||||
## 常见操作
|
||||
|
||||
### 忘记密码
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# 重置 admin 密码
|
||||
python -m app.gateway.auth.reset_admin
|
||||
|
||||
# 重置指定用户密码
|
||||
python -m app.gateway.auth.reset_admin --email user@example.com
|
||||
```
|
||||
|
||||
会输出新的随机密码。
|
||||
|
||||
### 完全重置
|
||||
|
||||
删除用户数据库,重启后自动创建新 admin:
|
||||
|
||||
```bash
|
||||
rm -f backend/.deer-flow/users.db
|
||||
# 重启服务,控制台输出新密码
|
||||
```
|
||||
|
||||
## 数据存储
|
||||
|
||||
| 文件 | 内容 |
|
||||
|------|------|
|
||||
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||
|
||||
### 生产环境建议
|
||||
|
||||
```bash
|
||||
# 生成持久化 JWT 密钥,避免重启后所有用户需重新登录
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
# 将输出添加到 .env:
|
||||
# AUTH_JWT_SECRET=<生成的密钥>
|
||||
```
|
||||
|
||||
## API 端点
|
||||
|
||||
| 端点 | 方法 | 说明 |
|
||||
|------|------|------|
|
||||
| `/api/v1/auth/login/local` | POST | 邮箱密码登录(OAuth2 form) |
|
||||
| `/api/v1/auth/register` | POST | 注册新用户(user 角色) |
|
||||
| `/api/v1/auth/logout` | POST | 登出(清除 cookie) |
|
||||
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
|
||||
| `/api/v1/auth/change-password` | POST | 修改密码 |
|
||||
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
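
For scripted access, the flow looks like the sketch below. This is a minimal illustration using Python's `requests`; the CSRF cookie name (`csrf_token`) is an assumption — inspect the actual `Set-Cookie` header your deployment returns and adjust.

```python
import requests

s = requests.Session()
# OAuth2 form login (username/password fields per the OAuth2 password form convention)
s.post(
    "http://localhost:2026/api/v1/auth/login/local",
    data={"username": "admin@deerflow.dev", "password": "<console password>"},
)
print(s.get("http://localhost:2026/api/v1/auth/me").json())

# Mutating requests must echo the CSRF token from the double-submit cookie.
csrf = s.cookies.get("csrf_token", "")  # assumed cookie name
s.post("http://localhost:2026/api/v1/auth/logout", headers={"X-CSRF-Token": csrf})
```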

## Compatibility

- **Standard mode** (`make dev`): fully compatible; the admin account is created automatically
- **Gateway mode** (`make dev-pro`): fully compatible
- **Docker deployments**: fully compatible; `.deer-flow/users.db` must live on a persistent volume mount
- **IM channels** (Feishu/Slack/Telegram): talk to the backend through the LangGraph SDK and do not pass through the auth layer
- **DeerFlowClient** (embedded): does not go over HTTP and is unaffected by authentication

## Troubleshooting

| Symptom | Cause | Fix |
|------|------|------|
| No password printed at startup | Admin already exists (not a first boot) | Reset with `reset_admin`, or delete `users.db` |
| POST returns 403 after login | CSRF token missing | Make sure the frontend is up to date |
| Must log in again after every restart | `AUTH_JWT_SECRET` not persisted | Set a fixed secret in `.env` |
@@ -8,6 +8,9 @@
  "graphs": {
    "lead_agent": "deerflow.agents:make_lead_agent"
  },
  "auth": {
    "path": "./app/gateway/langgraph_auth.py:auth"
  },
  "checkpointer": {
    "path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
  }

@@ -84,23 +84,76 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:


@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
    """Async context manager that yields a checkpointer for the caller's lifetime.
    Resources are opened on enter and closed on exit — no global state::

        async with make_checkpointer() as checkpointer:
            app.state.checkpointer = checkpointer

    Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
    """

    config = get_app_config()

    if config.checkpointer is None:
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
    """Async context manager that constructs a checkpointer from unified DatabaseConfig."""
    if db_config.backend == "memory":
        from langgraph.checkpoint.memory import InMemorySaver

        yield InMemorySaver()
        return

    async with _async_checkpointer(config.checkpointer) as saver:
        yield saver
    if db_config.backend == "sqlite":
        try:
            from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
        except ImportError as exc:
            raise ImportError(SQLITE_INSTALL) from exc

        conn_str = db_config.checkpointer_sqlite_path
        ensure_sqlite_parent_dir(conn_str)
        async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
            await saver.setup()
            yield saver
        return

    if db_config.backend == "postgres":
        try:
            from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
        except ImportError as exc:
            raise ImportError(POSTGRES_INSTALL) from exc

        if not db_config.postgres_url:
            raise ValueError("database.postgres_url is required for the postgres backend")

        async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
            await saver.setup()
            yield saver
        return

    raise ValueError(f"Unknown database backend: {db_config.backend!r}")


@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
    """Async context manager that yields a checkpointer for the caller's lifetime.
    Resources are opened on enter and closed on exit -- no global state::

        async with make_checkpointer() as checkpointer:
            app.state.checkpointer = checkpointer

    Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.

    Priority:
    1. Legacy ``checkpointer:`` config section (backward compatible)
    2. Unified ``database:`` config section
    3. Default InMemorySaver
    """

    config = get_app_config()

    # Legacy: standalone checkpointer config takes precedence
    if config.checkpointer is not None:
        async with _async_checkpointer(config.checkpointer) as saver:
            yield saver
        return

    # Unified database config
    db_config = getattr(config, "database", None)
    if db_config is not None and db_config.backend != "memory":
        async with _async_checkpointer_from_database(db_config) as saver:
            yield saver
        return

    # Default: in-memory
    from langgraph.checkpoint.memory import InMemorySaver

    yield InMemorySaver()

@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
    # Prepare keep parameter
    keep = config.keep.to_tuple()

    # Prepare model parameter
    # Prepare model parameter.
    # Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
    # as middleware rather than lead_agent (SummarizationMiddleware is a
    # LangChain built-in, so we tag the model at creation time).
    if config.model_name:
        model = create_chat_model(name=config.model_name, thinking_enabled=False)
    else:
        # Use a lightweight model for summarization to save costs
        # Falls back to default model if not explicitly specified
        model = create_chat_model(thinking_enabled=False)
    model = model.with_config(tags=["middleware:summarize"])

    # Prepare kwargs
    kwargs = {

@@ -164,6 +164,30 @@ Skip simple one-off tasks.
"""


def _skill_mutability_label(category: str) -> str:
    return "[custom, editable]" if category == "custom" else "[built-in]"


def clear_skills_system_prompt_cache() -> None:
    _get_cached_skills_prompt_section.cache_clear()


def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
    if not skill_evolution_enabled:
        return ""
    return """
## Skill Self-Evolution
After completing a task, consider creating or updating a skill when:
- The task required 5+ tool calls to resolve
- You overcame non-obvious errors or pitfalls
- The user corrected your approach and the corrected version worked
- You discovered a non-trivial, recurring workflow
If you used a skill and encountered issues not covered by it, patch it immediately.
Prefer patch over edit. Before creating a new skill, confirm with the user first.
Skip simple one-off tasks.
"""


def _build_subagent_section(max_concurrent: int) -> str:
    """Build the subagent system prompt section with dynamic concurrency limit.

@@ -1,10 +1,11 @@
"""Middleware for automatic thread title generation."""

import logging
from typing import NotRequired, override
from typing import Any, NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime

from deerflow.config.title_config import get_title_config
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
            return user_msg[:fallback_chars].rstrip() + "..."
        return user_msg if user_msg else "New Conversation"

    def _get_runnable_config(self) -> dict[str, Any]:
        """Inherit the parent RunnableConfig and add middleware tag.

        This ensures RunJournal identifies LLM calls from this middleware
        as ``middleware:title`` instead of ``lead_agent``.
        """
        try:
            parent = get_config()
        except Exception:
            parent = {}
        config = {**parent}
        config["tags"] = [*(config.get("tags") or []), "middleware:title"]
        return config

    def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
        """Generate a local fallback title without blocking on an LLM call."""
        if not self._should_generate_title(state):
@@ -121,7 +136,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
            model = create_chat_model(name=config.model_name, thinking_enabled=False)
        else:
            model = create_chat_model(thinking_enabled=False)
        response = await model.ainvoke(prompt)
        response = await model.ainvoke(prompt, config=self._get_runnable_config())
        title = self._parse_title(response.content)
        if title:
            return {"title": title}

@@ -10,10 +10,12 @@ from pydantic import BaseModel, ConfigDict, Field

from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -56,6 +58,8 @@ class AppConfig(BaseModel):
    subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
    guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
    model_config = ConfigDict(extra="allow", frozen=False)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
    run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
    checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
    stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")

@@ -0,0 +1,92 @@
"""Unified database backend configuration.

Controls BOTH the LangGraph checkpointer and the DeerFlow application
persistence layer (runs, threads metadata, users, etc.). The user
configures one backend; the system handles physical separation details.

SQLite mode: checkpointer and app use different .db files in the same
directory to avoid write-lock contention. This is automatic.

Postgres mode: both use the same database URL but maintain independent
connection pools with different lifecycles.

Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
No database is initialized.

Sensitive values (postgres_url) should use $VAR syntax in config.yaml
to reference environment variables from .env:

    database:
      backend: postgres
      postgres_url: $DATABASE_URL

The $VAR resolution is handled by AppConfig.resolve_env_variables()
before this config is instantiated -- DatabaseConfig itself does not
need to do any environment variable processing.
"""

from __future__ import annotations

import os
from typing import Literal

from pydantic import BaseModel, Field


class DatabaseConfig(BaseModel):
    backend: Literal["memory", "sqlite", "postgres"] = Field(
        default="memory",
        description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
    )
    sqlite_dir: str = Field(
        default=".deer-flow/data",
        description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
    )
    postgres_url: str = Field(
        default="",
        description=(
            "PostgreSQL connection URL, shared by checkpointer and app. "
            "Use $DATABASE_URL in config.yaml to reference .env. "
            "Example: postgresql://user:pass@host:5432/deerflow "
            "(the +asyncpg driver suffix is added automatically where needed)."
        ),
    )
    echo_sql: bool = Field(
        default=False,
        description="Echo all SQL statements to log (debug only).",
    )
    pool_size: int = Field(
        default=5,
        description="Connection pool size for the app ORM engine (postgres only).",
    )

    # -- Derived helpers (not user-configured) --

    @property
    def _resolved_sqlite_dir(self) -> str:
        """Resolve sqlite_dir to an absolute path (relative to CWD)."""
        from pathlib import Path

        return str(Path(self.sqlite_dir).resolve())

    @property
    def checkpointer_sqlite_path(self) -> str:
        """SQLite file path for the LangGraph checkpointer."""
        return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")

    @property
    def app_sqlite_path(self) -> str:
        """SQLite file path for application ORM data."""
        return os.path.join(self._resolved_sqlite_dir, "app.db")

    @property
    def app_sqlalchemy_url(self) -> str:
        """SQLAlchemy async URL for the application ORM engine."""
        if self.backend == "sqlite":
            return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
        if self.backend == "postgres":
            url = self.postgres_url
            if url.startswith("postgresql://"):
                url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
            return url
        raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")
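
To make the derived helpers concrete, here is a small usage sketch of the class exactly as defined above (paths and credentials are illustrative):

```python
cfg = DatabaseConfig(backend="sqlite", sqlite_dir="/var/lib/deerflow")
print(cfg.checkpointer_sqlite_path)  # /var/lib/deerflow/checkpoints.db
print(cfg.app_sqlalchemy_url)        # sqlite+aiosqlite:////var/lib/deerflow/app.db

pg = DatabaseConfig(backend="postgres", postgres_url="postgresql://deerflow:pw@db:5432/deerflow")
print(pg.app_sqlalchemy_url)         # postgresql+asyncpg://deerflow:pw@db:5432/deerflow
```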

@@ -0,0 +1,33 @@
"""Run event storage configuration.

Controls where run events (messages + execution traces) are persisted.

Backends:
- memory: In-memory storage, data lost on restart. Suitable for
  development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
  Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
  single-node deployments that need persistence without a database.
"""

from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, Field


class RunEventsConfig(BaseModel):
    backend: Literal["memory", "db", "jsonl"] = Field(
        default="memory",
        description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
    )
    max_trace_content: int = Field(
        default=10240,
        description="Maximum trace content size in bytes before truncation (db backend only).",
    )
    track_token_usage: bool = Field(
        default=True,
        description="Whether RunJournal should accumulate token counts to RunRow.",
    )
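
A minimal sketch of selecting a non-default backend programmatically (field names are exactly those of the model above; in practice these values arrive from config.yaml via AppConfig):

```python
run_events = RunEventsConfig(backend="jsonl", max_trace_content=4096)
assert run_events.track_token_usage  # defaults to True
```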

@@ -113,7 +113,16 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
    elif "reasoning_effort" not in model_settings_from_config:
        model_settings_from_config["reasoning_effort"] = "medium"

    model_instance = model_class(**{**model_settings_from_config, **kwargs})
    # Ensure stream_usage is enabled so that token usage metadata is available
    # in streaming responses. LangChain's BaseChatOpenAI only defaults
    # stream_usage=True when no custom base_url/api_base is set, so models
    # hitting third-party endpoints (e.g. doubao, deepseek) silently lose
    # usage data. We default it to True unless explicitly configured.
    if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
        if "stream_usage" in getattr(model_class, "model_fields", {}):
            model_settings_from_config["stream_usage"] = True

    model_instance = model_class(**kwargs, **model_settings_from_config)

    callbacks = build_tracing_callbacks()
    if callbacks:

@@ -0,0 +1,13 @@
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).

This module manages DeerFlow's own application data -- runs metadata,
thread ownership, cron jobs, users. It is completely separate from
LangGraph's checkpointer, which manages graph execution state.

Usage:
    from deerflow.persistence import init_engine, close_engine, get_session_factory
"""

from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine

__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]

@@ -0,0 +1,40 @@
"""SQLAlchemy declarative base with automatic to_dict support.

All DeerFlow ORM models inherit from this Base. It provides a generic
to_dict() method via SQLAlchemy's inspect() so individual models don't
need to write their own serialization logic.

LangGraph's checkpointer tables are NOT managed by this Base.
"""

from __future__ import annotations

from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    """Base class for all DeerFlow ORM models.

    Provides:
    - Automatic to_dict() via SQLAlchemy column inspection.
    - Standard __repr__() showing all column values.
    """

    def to_dict(self, *, exclude: set[str] | None = None) -> dict:
        """Convert ORM instance to plain dict.

        Uses SQLAlchemy's inspect() to iterate mapped column attributes.

        Args:
            exclude: Optional set of column keys to omit.

        Returns:
            Dict of {column_key: value} for all mapped columns.
        """
        exclude = exclude or set()
        return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude}

    def __repr__(self) -> str:
        cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs)
        return f"{type(self).__name__}({cols})"
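
To see what the generic `to_dict()`/`__repr__()` buy, consider a hypothetical model (illustration only, not part of this changeset):

```python
from sqlalchemy.orm import Mapped, mapped_column

class Note(Base):
    __tablename__ = "notes"

    id: Mapped[int] = mapped_column(primary_key=True)
    text: Mapped[str] = mapped_column(default="")

note = Note(id=1, text="hello")
print(note.to_dict())                  # {'id': 1, 'text': 'hello'}
print(note.to_dict(exclude={"text"}))  # {'id': 1}
print(note)                            # Note(id=1, text='hello')
```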

@@ -0,0 +1,185 @@
"""Async SQLAlchemy engine lifecycle management.

Initializes at Gateway startup, provides session factory for
repositories, disposes at shutdown.

When database.backend="memory", init_engine is a no-op and
get_session_factory() returns None. Repositories must check for
None and fall back to in-memory implementations.
"""

from __future__ import annotations

import json
import logging

from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine


def _json_serializer(obj: object) -> str:
    """JSON serializer with ensure_ascii=False for Chinese character support."""
    return json.dumps(obj, ensure_ascii=False)
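
# Why ensure_ascii=False matters here (editor's note, not part of this changeset):
#     json.dumps({"title": "标题"})                      -> '{"title": "\\u6807\\u9898"}'
#     json.dumps({"title": "标题"}, ensure_ascii=False)  -> '{"title": "标题"}'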
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
async def _auto_create_postgres_db(url: str) -> None:
|
||||
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
|
||||
|
||||
The target database name is extracted from *url*. The connection is
|
||||
made to the default ``postgres`` database on the same server using
|
||||
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
|
||||
transaction).
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
parsed = make_url(url)
|
||||
db_name = parsed.database
|
||||
if not db_name:
|
||||
raise ValueError("Cannot auto-create database: no database name in URL")
|
||||
|
||||
# Connect to the default 'postgres' database to issue CREATE DATABASE
|
||||
maint_url = parsed.set(database="postgres")
|
||||
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with maint_engine.connect() as conn:
|
||||
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
|
||||
logger.info("Auto-created PostgreSQL database: %s", db_name)
|
||||
finally:
|
||||
await maint_engine.dispose()
|
||||
|
||||
|
||||
async def init_engine(
|
||||
backend: str,
|
||||
*,
|
||||
url: str = "",
|
||||
echo: bool = False,
|
||||
pool_size: int = 5,
|
||||
sqlite_dir: str = "",
|
||||
) -> None:
|
||||
"""Create the async engine and session factory, then auto-create tables.
|
||||
|
||||
Args:
|
||||
backend: "memory", "sqlite", or "postgres".
|
||||
url: SQLAlchemy async URL (for sqlite/postgres).
|
||||
echo: Echo SQL to log.
|
||||
pool_size: Postgres connection pool size.
|
||||
sqlite_dir: Directory to create for SQLite (ensured to exist).
|
||||
"""
|
||||
global _engine, _session_factory
|
||||
|
||||
if backend == "memory":
|
||||
logger.info("Persistence backend=memory -- ORM engine not initialized")
|
||||
return
|
||||
|
||||
if backend == "postgres":
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
|
||||
|
||||
if backend == "sqlite":
|
||||
import os
|
||||
|
||||
from sqlalchemy import event
|
||||
|
||||
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
||||
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
|
||||
|
||||
# Enable WAL on every new connection. SQLite PRAGMA settings are
|
||||
# per-connection, so we wire the listener instead of running PRAGMA
|
||||
# once at startup. WAL gives concurrent reads + writers without
|
||||
# blocking and is the standard recommendation for any production
|
||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||
# at WAL checkpoint boundaries instead of every commit.
|
||||
@event.listens_for(_engine.sync_engine, "connect")
|
||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
elif backend == "postgres":
|
||||
_engine = create_async_engine(
|
||||
url,
|
||||
echo=echo,
|
||||
pool_size=pool_size,
|
||||
pool_pre_ping=True,
|
||||
json_serializer=_json_serializer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown persistence backend: {backend!r}")
|
||||
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
|
||||
# Auto-create tables (dev convenience). Production should use Alembic.
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
# Import all models so Base.metadata discovers them.
|
||||
# When no models exist yet (scaffolding phase), this is a no-op.
|
||||
try:
|
||||
import deerflow.persistence.models # noqa: F401
|
||||
except ImportError:
|
||||
# Models package not yet available — tables won't be auto-created.
|
||||
# This is expected during initial scaffolding or minimal installs.
|
||||
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
|
||||
|
||||
try:
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
except Exception as exc:
|
||||
if backend == "postgres" and "does not exist" in str(exc):
|
||||
# Database not yet created — attempt to auto-create it, then retry.
|
||||
await _auto_create_postgres_db(url)
|
||||
# Rebuild engine against the now-existing database
|
||||
await _engine.dispose()
|
||||
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info("Persistence engine initialized: backend=%s", backend)
|
||||
|
||||
|
||||
async def init_engine_from_config(config) -> None:
|
||||
"""Convenience: init engine from a DatabaseConfig object."""
|
||||
if config.backend == "memory":
|
||||
await init_engine("memory")
|
||||
return
|
||||
await init_engine(
|
||||
backend=config.backend,
|
||||
url=config.app_sqlalchemy_url,
|
||||
echo=config.echo_sql,
|
||||
pool_size=config.pool_size,
|
||||
sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "",
|
||||
)
|
||||
|
||||
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession] | None:
|
||||
"""Return the async session factory, or None if backend=memory."""
|
||||
return _session_factory
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine | None:
|
||||
"""Return the async engine, or None if not initialized."""
|
||||
return _engine
|
||||
|
||||
|
||||
async def close_engine() -> None:
|
||||
"""Dispose the engine, release all connections."""
|
||||
global _engine, _session_factory
|
||||
if _engine is not None:
|
||||
await _engine.dispose()
|
||||
logger.info("Persistence engine closed")
|
||||
_engine = None
|
||||
_session_factory = None
|
||||
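For orientation, a minimal startup/shutdown sketch against the engine API above; the SQLite URL and the ./data directory are illustrative values, not project defaults.

import asyncio

from deerflow.persistence.engine import close_engine, get_session_factory, init_engine


async def _demo() -> None:
    # The WAL/synchronous/foreign_keys PRAGMAs are applied per-connection
    # by the event listener wired inside init_engine.
    await init_engine("sqlite", url="sqlite+aiosqlite:///./data/app.db", sqlite_dir="./data")
    sf = get_session_factory()
    assert sf is not None  # None only when backend="memory"
    async with sf() as session:  # short-lived session, the same pattern the repositories use
        pass
    await close_engine()


asyncio.run(_demo())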
@@ -0,0 +1,6 @@
"""Feedback persistence — ORM and SQL repository."""

from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.feedback.sql import FeedbackRepository

__all__ = ["FeedbackRepository", "FeedbackRow"]
@@ -0,0 +1,30 @@
"""ORM model for user feedback on runs."""

from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import DateTime, String, Text
from sqlalchemy.orm import Mapped, mapped_column

from deerflow.persistence.base import Base


class FeedbackRow(Base):
    __tablename__ = "feedback"

    feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    message_id: Mapped[str | None] = mapped_column(String(64))
    # message_id is an optional RunEventStore event identifier —
    # allows feedback to target a specific message or the entire run

    rating: Mapped[int] = mapped_column(nullable=False)
    # +1 (thumbs-up) or -1 (thumbs-down)

    comment: Mapped[str | None] = mapped_column(Text)
    # Optional text feedback from the user

    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
@@ -0,0 +1,139 @@
"""SQLAlchemy-backed feedback storage.

Each method acquires its own short-lived session.
"""

from __future__ import annotations

import uuid
from datetime import UTC, datetime

from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id


class FeedbackRepository:
    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _row_to_dict(row: FeedbackRow) -> dict:
        d = row.to_dict()
        val = d.get("created_at")
        if isinstance(val, datetime):
            d["created_at"] = val.isoformat()
        return d

    async def create(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        owner_id: str | None | _AutoSentinel = AUTO,
        message_id: str | None = None,
        comment: str | None = None,
    ) -> dict:
        """Create a feedback record. rating must be +1 or -1."""
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
        resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
        row = FeedbackRow(
            feedback_id=str(uuid.uuid4()),
            run_id=run_id,
            thread_id=thread_id,
            owner_id=resolved_owner_id,
            message_id=message_id,
            rating=rating,
            comment=comment,
            created_at=datetime.now(UTC),
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()
            await session.refresh(row)
            return self._row_to_dict(row)

    async def get(
        self,
        feedback_id: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> dict | None:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
        async with self._sf() as session:
            row = await session.get(FeedbackRow, feedback_id)
            if row is None:
                return None
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return None
            return self._row_to_dict(row)

    async def list_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 100,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> list[dict]:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
        stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
        if resolved_owner_id is not None:
            stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
        stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def list_by_thread(
        self,
        thread_id: str,
        *,
        limit: int = 100,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> list[dict]:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
        stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
        if resolved_owner_id is not None:
            stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
        stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def delete(
        self,
        feedback_id: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> bool:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
        async with self._sf() as session:
            row = await session.get(FeedbackRow, feedback_id)
            if row is None:
                return False
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
        """Aggregate feedback stats for a run using database-side counting."""
        stmt = select(
            func.count().label("total"),
            func.coalesce(func.sum(case((FeedbackRow.rating == 1, 1), else_=0)), 0).label("positive"),
            func.coalesce(func.sum(case((FeedbackRow.rating == -1, 1), else_=0)), 0).label("negative"),
        ).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
        async with self._sf() as session:
            row = (await session.execute(stmt)).one()
            return {
                "run_id": run_id,
                "total": row.total,
                "positive": row.positive,
                "negative": row.negative,
            }
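A hypothetical call sequence (inside an async context) against the repository above; the IDs and comment text are invented for illustration.

repo = FeedbackRepository(get_session_factory())
rec = await repo.create(
    run_id="run-1",
    thread_id="t-1",
    rating=1,
    owner_id="user-1",  # explicit owner; the AUTO default resolves from the request contextvar
    comment="helpful answer",
)
stats = await repo.aggregate_by_run("t-1", "run-1")
# -> {"run_id": "run-1", "total": 1, "positive": 1, "negative": 0}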
@@ -0,0 +1,38 @@
[alembic]
script_location = %(here)s
# Default URL for offline mode / autogenerate.
# Runtime uses the engine from DeerFlow config.
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
@@ -0,0 +1,65 @@
"""Alembic environment for DeerFlow application tables.

ONLY manages DeerFlow's tables (runs, threads_meta, cron_jobs, users).
LangGraph's checkpointer tables are managed by LangGraph itself -- they
have their own schema lifecycle and must not be touched by Alembic.
"""

from __future__ import annotations

import asyncio
import logging
from logging.config import fileConfig

from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine

from deerflow.persistence.base import Base

# Import all models so metadata is populated.
try:
    import deerflow.persistence.models  # noqa: F401 — register ORM models with Base.metadata
except ImportError:
    # Models not available — migration will work with existing metadata only.
    logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata


def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        render_as_batch=True,
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection):
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        render_as_batch=True,  # Required for SQLite ALTER TABLE support
    )
    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online() -> None:
    connectable = create_async_engine(config.get_main_option("sqlalchemy.url"))
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


if context.is_offline_mode():
    run_migrations_offline()
else:
    asyncio.run(run_migrations_online())
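Assuming the standard Alembic CLI is invoked from the directory containing this alembic.ini, generating and applying a revision would look roughly like (the message string is illustrative):

alembic revision --autogenerate -m "add app tables"
alembic upgrade head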
@@ -0,0 +1,23 @@
"""ORM model registration entry point.

Importing this module ensures all ORM models are registered with
``Base.metadata`` so Alembic autogenerate detects every table.

The actual ORM classes have moved to entity-specific subpackages:
- ``deerflow.persistence.thread_meta``
- ``deerflow.persistence.run``
- ``deerflow.persistence.feedback``
- ``deerflow.persistence.user``

``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
its storage implementation lives in ``deerflow.runtime.events.store.db`` and
there is no matching entity directory.
"""

from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow

__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
@@ -0,0 +1,35 @@
"""ORM model for run events."""

from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from deerflow.persistence.base import Base


class RunEventRow(Base):
    __tablename__ = "run_events"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
    run_id: Mapped[str] = mapped_column(String(64), nullable=False)
    # Owner of the conversation this event belongs to. Nullable for data
    # created before auth was introduced; populated by auth middleware on
    # new writes and by the boot-time orphan migration on existing rows.
    owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
    event_type: Mapped[str] = mapped_column(String(32), nullable=False)
    category: Mapped[str] = mapped_column(String(16), nullable=False)
    # "message" | "trace" | "lifecycle"
    content: Mapped[str] = mapped_column(Text, default="")
    event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
    seq: Mapped[int] = mapped_column(nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))

    __table_args__ = (
        UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        Index("ix_events_run", "thread_id", "run_id", "seq"),
    )
@@ -0,0 +1,6 @@
"""Run metadata persistence — ORM and SQL repository."""

from deerflow.persistence.run.model import RunRow
from deerflow.persistence.run.sql import RunRepository

__all__ = ["RunRepository", "RunRow"]
@@ -0,0 +1,49 @@
"""ORM model for run metadata."""

from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import JSON, DateTime, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column

from deerflow.persistence.base import Base


class RunRow(Base):
    __tablename__ = "runs"

    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    assistant_id: Mapped[str | None] = mapped_column(String(128))
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    status: Mapped[str] = mapped_column(String(20), default="pending")
    # "pending" | "running" | "success" | "error" | "timeout" | "interrupted"

    model_name: Mapped[str | None] = mapped_column(String(128))
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
    kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
    error: Mapped[str | None] = mapped_column(Text)

    # Convenience fields (for listing pages without querying RunEventStore)
    message_count: Mapped[int] = mapped_column(default=0)
    first_human_message: Mapped[str | None] = mapped_column(Text)
    last_ai_message: Mapped[str | None] = mapped_column(Text)

    # Token usage (accumulated in-memory by RunJournal, written on run completion)
    total_input_tokens: Mapped[int] = mapped_column(default=0)
    total_output_tokens: Mapped[int] = mapped_column(default=0)
    total_tokens: Mapped[int] = mapped_column(default=0)
    llm_call_count: Mapped[int] = mapped_column(default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(default=0)
    subagent_tokens: Mapped[int] = mapped_column(default=0)
    middleware_tokens: Mapped[int] = mapped_column(default=0)

    # Follow-up association
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))

    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))

    __table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
@@ -0,0 +1,255 @@
"""SQLAlchemy-backed RunStore implementation.

Each method acquires and releases its own short-lived session.
Run status updates come from background workers that may live for
minutes -- we don't hold connections across long executions.
"""

from __future__ import annotations

import json
from datetime import UTC, datetime
from typing import Any

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id


class RunRepository(RunStore):
    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _safe_json(obj: Any) -> Any:
        """Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
        if obj is None:
            return None
        if isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, dict):
            return {k: RunRepository._safe_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [RunRepository._safe_json(v) for v in obj]
        if hasattr(obj, "model_dump"):
            try:
                return obj.model_dump()
            except Exception:
                pass
        if hasattr(obj, "dict"):
            try:
                return obj.dict()
            except Exception:
                pass
        try:
            json.dumps(obj)
            return obj
        except (TypeError, ValueError):
            return str(obj)

    @staticmethod
    def _row_to_dict(row: RunRow) -> dict[str, Any]:
        d = row.to_dict()
        # Remap JSON columns to match the RunStore interface
        d["metadata"] = d.pop("metadata_json", {})
        d["kwargs"] = d.pop("kwargs_json", {})
        # Convert datetime to ISO string for consistency with MemoryRunStore
        for key in ("created_at", "updated_at"):
            val = d.get(key)
            if isinstance(val, datetime):
                d[key] = val.isoformat()
        return d

    async def put(
        self,
        run_id,
        *,
        thread_id,
        assistant_id=None,
        owner_id: str | None | _AutoSentinel = AUTO,
        status="pending",
        multitask_strategy="reject",
        metadata=None,
        kwargs=None,
        error=None,
        created_at=None,
        follow_up_to_run_id=None,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
        now = datetime.now(UTC)
        row = RunRow(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            owner_id=resolved_owner_id,
            status=status,
            multitask_strategy=multitask_strategy,
            metadata_json=self._safe_json(metadata) or {},
            kwargs_json=self._safe_json(kwargs) or {},
            error=error,
            follow_up_to_run_id=follow_up_to_run_id,
            created_at=datetime.fromisoformat(created_at) if created_at else now,
            updated_at=now,
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()

    async def get(
        self,
        run_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            if row is None:
                return None
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return None
            return self._row_to_dict(row)

    async def list_by_thread(
        self,
        thread_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
        limit=100,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
        stmt = select(RunRow).where(RunRow.thread_id == thread_id)
        if resolved_owner_id is not None:
            stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
        stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_status(self, run_id, status, *, error=None):
        values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()

    async def delete(
        self,
        run_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            if row is None:
                return
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return
            await session.delete(row)
            await session.commit()

    async def list_pending(self, *, before=None):
        if before is None:
            before_dt = datetime.now(UTC)
        elif isinstance(before, datetime):
            before_dt = before
        else:
            before_dt = datetime.fromisoformat(before)
        stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc())
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        last_ai_message: str | None = None,
        first_human_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Update status + token usage + convenience fields on run completion."""
        values: dict[str, Any] = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
            "updated_at": datetime.now(UTC),
        }
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Aggregate token usage via a single SQL GROUP BY query."""
        _completed = RunRow.status.in_(("success", "error"))
        _thread = RunRow.thread_id == thread_id

        stmt = (
            select(
                func.coalesce(RunRow.model_name, "unknown").label("model"),
                func.count().label("runs"),
                func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
                func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
                func.coalesce(func.sum(RunRow.total_output_tokens), 0).label("total_output_tokens"),
                func.coalesce(func.sum(RunRow.lead_agent_tokens), 0).label("lead_agent"),
                func.coalesce(func.sum(RunRow.subagent_tokens), 0).label("subagent"),
                func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
            )
            .where(_thread, _completed)
            .group_by(func.coalesce(RunRow.model_name, "unknown"))
        )

        async with self._sf() as session:
            rows = (await session.execute(stmt)).all()

        total_tokens = total_input = total_output = total_runs = 0
        lead_agent = subagent = middleware = 0
        by_model: dict[str, dict] = {}
        for r in rows:
            by_model[r.model] = {"tokens": r.total_tokens, "runs": r.runs}
            total_tokens += r.total_tokens
            total_input += r.total_input_tokens
            total_output += r.total_output_tokens
            total_runs += r.runs
            lead_agent += r.lead_agent
            subagent += r.subagent
            middleware += r.middleware

        return {
            "total_tokens": total_tokens,
            "total_input_tokens": total_input,
            "total_output_tokens": total_output,
            "total_runs": total_runs,
            "by_model": by_model,
            "by_caller": {
                "lead_agent": lead_agent,
                "subagent": subagent,
                "middleware": middleware,
            },
        }
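A sketch of the intended run lifecycle against RunRepository (inside an async context); all IDs and token counts are invented for illustration.

repo = RunRepository(get_session_factory())
await repo.put("run-1", thread_id="t-1", owner_id="user-1", metadata={"source": "api"})
await repo.update_status("run-1", "running")
await repo.update_run_completion(
    "run-1",
    status="success",
    total_input_tokens=1200,
    total_output_tokens=300,
    total_tokens=1500,
    llm_call_count=3,
    message_count=4,
    last_ai_message="Done.",
)
usage = await repo.aggregate_tokens_by_thread("t-1")  # per-model and per-caller totals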
@@ -0,0 +1,13 @@
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""

from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository

__all__ = [
    "MemoryThreadMetaStore",
    "ThreadMetaRepository",
    "ThreadMetaRow",
    "ThreadMetaStore",
]
@@ -0,0 +1,60 @@
"""Abstract interface for thread metadata storage.

Implementations:
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
"""

from __future__ import annotations

import abc


class ThreadMetaStore(abc.ABC):
    @abc.abstractmethod
    async def create(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        owner_id: str | None = None,
        display_name: str | None = None,
        metadata: dict | None = None,
    ) -> dict:
        pass

    @abc.abstractmethod
    async def get(self, thread_id: str) -> dict | None:
        pass

    @abc.abstractmethod
    async def search(
        self,
        *,
        metadata: dict | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict]:
        pass

    @abc.abstractmethod
    async def update_display_name(self, thread_id: str, display_name: str) -> None:
        pass

    @abc.abstractmethod
    async def update_status(self, thread_id: str, status: str) -> None:
        pass

    @abc.abstractmethod
    async def update_metadata(self, thread_id: str, metadata: dict) -> None:
        """Merge ``metadata`` into the thread's metadata field.

        Existing keys are overwritten by the new values; keys absent from
        ``metadata`` are preserved. No-op if the thread does not exist.
        """
        pass

    @abc.abstractmethod
    async def delete(self, thread_id: str) -> None:
        pass
@@ -0,0 +1,120 @@
"""In-memory ThreadMetaStore backed by LangGraph BaseStore.

Used when database.backend=memory. Delegates to the LangGraph Store's
``("threads",)`` namespace — the same namespace used by the Gateway
router for thread records.
"""

from __future__ import annotations

import time
from typing import Any

from langgraph.store.base import BaseStore

from deerflow.persistence.thread_meta.base import ThreadMetaStore

THREADS_NS: tuple[str, ...] = ("threads",)


class MemoryThreadMetaStore(ThreadMetaStore):
    def __init__(self, store: BaseStore) -> None:
        self._store = store

    async def create(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        owner_id: str | None = None,
        display_name: str | None = None,
        metadata: dict | None = None,
    ) -> dict:
        now = time.time()
        record: dict[str, Any] = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "owner_id": owner_id,
            "display_name": display_name,
            "status": "idle",
            "metadata": metadata or {},
            "values": {},
            "created_at": now,
            "updated_at": now,
        }
        await self._store.aput(THREADS_NS, thread_id, record)
        return record

    async def get(self, thread_id: str) -> dict | None:
        item = await self._store.aget(THREADS_NS, thread_id)
        return item.value if item is not None else None

    async def search(
        self,
        *,
        metadata: dict | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict]:
        filter_dict: dict[str, Any] = {}
        if metadata:
            filter_dict.update(metadata)
        if status:
            filter_dict["status"] = status

        items = await self._store.asearch(
            THREADS_NS,
            filter=filter_dict or None,
            limit=limit,
            offset=offset,
        )
        return [self._item_to_dict(item) for item in items]

    async def update_display_name(self, thread_id: str, display_name: str) -> None:
        item = await self._store.aget(THREADS_NS, thread_id)
        if item is None:
            return
        record = dict(item.value)
        record["display_name"] = display_name
        record["updated_at"] = time.time()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def update_status(self, thread_id: str, status: str) -> None:
        item = await self._store.aget(THREADS_NS, thread_id)
        if item is None:
            return
        record = dict(item.value)
        record["status"] = status
        record["updated_at"] = time.time()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def update_metadata(self, thread_id: str, metadata: dict) -> None:
        """Merge ``metadata`` into the in-memory record. No-op if absent."""
        item = await self._store.aget(THREADS_NS, thread_id)
        if item is None:
            return
        record = dict(item.value)
        merged = dict(record.get("metadata") or {})
        merged.update(metadata)
        record["metadata"] = merged
        record["updated_at"] = time.time()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def delete(self, thread_id: str) -> None:
        await self._store.adelete(THREADS_NS, thread_id)

    @staticmethod
    def _item_to_dict(item) -> dict[str, Any]:
        """Convert a Store SearchItem to the dict format expected by callers."""
        val = item.value
        return {
            "thread_id": item.key,
            "assistant_id": val.get("assistant_id"),
            "owner_id": val.get("owner_id"),
            "display_name": val.get("display_name"),
            "status": val.get("status", "idle"),
            "metadata": val.get("metadata", {}),
            "created_at": str(val.get("created_at", "")),
            "updated_at": str(val.get("updated_at", "")),
        }
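A minimal sketch of memory mode, assuming LangGraph's InMemoryStore as the BaseStore implementation; the thread ID, owner, and title are invented.

from langgraph.store.memory import InMemoryStore

store = MemoryThreadMetaStore(InMemoryStore())
await store.create("t-1", owner_id="user-1", display_name="Research: quantum")
idle_threads = await store.search(status="idle", limit=10)
await store.update_metadata("t-1", {"pinned": True})  # merged; existing keys preserved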
@@ -0,0 +1,23 @@
"""ORM model for thread metadata."""

from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import JSON, DateTime, String
from sqlalchemy.orm import Mapped, mapped_column

from deerflow.persistence.base import Base


class ThreadMetaRow(Base):
    __tablename__ = "threads_meta"

    thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    display_name: Mapped[str | None] = mapped_column(String(256))
    status: Mapped[str] = mapped_column(String(20), default="idle")
    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
@@ -0,0 +1,223 @@
"""SQLAlchemy-backed thread metadata repository."""

from __future__ import annotations

from datetime import UTC, datetime
from typing import Any

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id


class ThreadMetaRepository(ThreadMetaStore):
    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
        d = row.to_dict()
        d["metadata"] = d.pop("metadata_json", {})
        for key in ("created_at", "updated_at"):
            val = d.get(key)
            if isinstance(val, datetime):
                d[key] = val.isoformat()
        return d

    async def create(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        owner_id: str | None | _AutoSentinel = AUTO,
        display_name: str | None = None,
        metadata: dict | None = None,
    ) -> dict:
        # Auto-resolve owner_id from the contextvar when AUTO; explicit None
        # creates an orphan row (used by migration scripts).
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
        now = datetime.now(UTC)
        row = ThreadMetaRow(
            thread_id=thread_id,
            assistant_id=assistant_id,
            owner_id=resolved_owner_id,
            display_name=display_name,
            metadata_json=metadata or {},
            created_at=now,
            updated_at=now,
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()
            await session.refresh(row)
            return self._row_to_dict(row)

    async def get(
        self,
        thread_id: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> dict | None:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
        async with self._sf() as session:
            row = await session.get(ThreadMetaRow, thread_id)
            if row is None:
                return None
            # Enforce the owner filter unless explicitly bypassed (owner_id=None).
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return None
            return self._row_to_dict(row)

    async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
        stmt = (
            select(ThreadMetaRow)
            .where(ThreadMetaRow.owner_id == owner_id)
            .order_by(ThreadMetaRow.updated_at.desc())
            .limit(limit)
            .offset(offset)
        )
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
        """Check if ``owner_id`` has access to ``thread_id``.

        Two modes — one row, two distinct semantics depending on what
        the caller is about to do:

        - ``require_existing=False`` (default, permissive):
          Returns True for: row missing (untracked legacy thread),
          ``row.owner_id`` is None (shared / pre-auth data),
          or ``row.owner_id == owner_id``. Use for **read-style**
          decorators where treating an untracked thread as accessible
          preserves backward-compat.

        - ``require_existing=True`` (strict):
          Returns True **only** when the row exists AND
          (``row.owner_id == owner_id`` OR ``row.owner_id is None``).
          Use for **destructive / mutating** decorators (DELETE, PATCH,
          state-update) so a thread that has *already been deleted*
          cannot be re-targeted by any caller — closing the
          delete-idempotence cross-user gap where the row vanishing
          made every other user appear to "own" it.
        """
        async with self._sf() as session:
            row = await session.get(ThreadMetaRow, thread_id)
            if row is None:
                return not require_existing
            if row.owner_id is None:
                return True
            return row.owner_id == owner_id

    async def search(
        self,
        *,
        metadata: dict | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> list[dict]:
        """Search threads with optional metadata and status filters.

        The owner filter is enforced by default: the caller must be in a
        user context. Pass ``owner_id=None`` to bypass (migration/CLI).
        """
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
        stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
        if resolved_owner_id is not None:
            stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
        if status:
            stmt = stmt.where(ThreadMetaRow.status == status)

        if metadata:
            # When a metadata filter is active, fetch a larger window and filter
            # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
            # SQLite json_extract) for server-side filtering.
            stmt = stmt.limit(limit * 5 + offset)
            async with self._sf() as session:
                result = await session.execute(stmt)
                rows = [self._row_to_dict(r) for r in result.scalars()]
            rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
            return rows[offset : offset + limit]
        else:
            stmt = stmt.limit(limit).offset(offset)
            async with self._sf() as session:
                result = await session.execute(stmt)
                return [self._row_to_dict(r) for r in result.scalars()]

    async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
        """Return True if the row exists and is owned (or the filter is bypassed)."""
        if resolved_owner_id is None:
            return True  # explicit bypass
        row = await session.get(ThreadMetaRow, thread_id)
        return row is not None and row.owner_id == resolved_owner_id

    async def update_display_name(
        self,
        thread_id: str,
        display_name: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> None:
        """Update the display_name (title) for a thread."""
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
        async with self._sf() as session:
            if not await self._check_ownership(session, thread_id, resolved_owner_id):
                return
            await session.execute(
                update(ThreadMetaRow)
                .where(ThreadMetaRow.thread_id == thread_id)
                .values(display_name=display_name, updated_at=datetime.now(UTC))
            )
            await session.commit()

    async def update_status(
        self,
        thread_id: str,
        status: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> None:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
        async with self._sf() as session:
            if not await self._check_ownership(session, thread_id, resolved_owner_id):
                return
            await session.execute(
                update(ThreadMetaRow)
                .where(ThreadMetaRow.thread_id == thread_id)
                .values(status=status, updated_at=datetime.now(UTC))
            )
            await session.commit()

    async def update_metadata(
        self,
        thread_id: str,
        metadata: dict,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> None:
        """Merge ``metadata`` into ``metadata_json``.

        Read-modify-write inside a single session/transaction so concurrent
        callers see consistent state. No-op if the row does not exist or
        the owner_id check fails.
        """
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
        async with self._sf() as session:
            row = await session.get(ThreadMetaRow, thread_id)
            if row is None:
                return
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return
            merged = dict(row.metadata_json or {})
            merged.update(metadata)
            row.metadata_json = merged
            row.updated_at = datetime.now(UTC)
            await session.commit()

    async def delete(
        self,
        thread_id: str,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ) -> None:
        resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
        async with self._sf() as session:
            row = await session.get(ThreadMetaRow, thread_id)
            if row is None:
                return
            if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
                return
            await session.delete(row)
            await session.commit()
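To make the two check_access modes concrete, a sketch for a thread whose row has been deleted; the thread and user IDs are illustrative.

repo = ThreadMetaRepository(get_session_factory())
# Read path (permissive): an untracked legacy thread is still treated as accessible.
can_read = await repo.check_access("t-deleted", "user-1")
# -> True when the row is missing
# Mutation path (strict): a deleted thread cannot be re-targeted by anyone.
can_mutate = await repo.check_access("t-deleted", "user-1", require_existing=True)
# -> False when the row is missing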
@@ -0,0 +1,12 @@
"""User storage subpackage.

Holds the ORM model for the ``users`` table. The concrete repository
implementation (``SQLiteUserRepository``) lives in the app layer
(``app.gateway.auth.repositories.sqlite``) because it converts
between the ORM row and the auth module's pydantic ``User`` class.
This keeps the harness package free of any dependency on app code.
"""

from deerflow.persistence.user.model import UserRow

__all__ = ["UserRow"]
@@ -0,0 +1,59 @@
"""ORM model for the users table.

Lives in the harness persistence package so it is picked up by
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
``run_events``, and ``feedback``. Using the shared engine means:

- One SQLite/Postgres database, one connection pool
- One schema initialisation codepath
- Consistent async sessions across auth and persistence reads
"""

from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import Boolean, DateTime, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column

from deerflow.persistence.base import Base


class UserRow(Base):
    __tablename__ = "users"

    # UUIDs are stored as 36-char strings for cross-backend portability.
    id: Mapped[str] = mapped_column(String(36), primary_key=True)

    email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
    password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)

    # "admin" | "user" — kept as a plain string to avoid ALTER TABLE pain
    # when new roles are introduced.
    system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(UTC),
    )

    # OAuth linkage (optional). A partial unique index enforces one
    # account per (provider, oauth_id) pair, leaving NULL/NULL rows
    # unconstrained so plain password accounts can coexist.
    oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
    oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)

    # Auth lifecycle flags
    needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    token_version: Mapped[int] = mapped_column(nullable=False, default=0)

    __table_args__ = (
        Index(
            "idx_users_oauth_identity",
            "oauth_provider",
            "oauth_id",
            unique=True,
            sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
        ),
    )
@@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``.
"""

from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
@@ -14,6 +14,7 @@ __all__ = [
    # runs
    "ConflictError",
    "DisconnectMode",
    "RunContext",
    "RunManager",
    "RunRecord",
    "RunStatus",

@@ -0,0 +1,134 @@
"""Pure functions to convert LangChain message objects to OpenAI Chat Completions format.

Used by RunJournal to build content dicts for event storage.
"""

from __future__ import annotations

import json
from typing import Any

_ROLE_MAP = {
    "human": "user",
    "ai": "assistant",
    "system": "system",
    "tool": "tool",
}


def langchain_to_openai_message(message: Any) -> dict:
    """Convert a single LangChain BaseMessage to an OpenAI message dict.

    Handles:
    - HumanMessage → {"role": "user", "content": "..."}
    - AIMessage (text only) → {"role": "assistant", "content": "..."}
    - AIMessage (with tool_calls) → {"role": "assistant", "content": null, "tool_calls": [...]}
    - AIMessage (text + tool_calls) → both content and tool_calls present
    - AIMessage (list content / multimodal) → content preserved as a list
    - SystemMessage → {"role": "system", "content": "..."}
    - ToolMessage → {"role": "tool", "tool_call_id": "...", "content": "..."}
    """
    msg_type = getattr(message, "type", "")
    role = _ROLE_MAP.get(msg_type, msg_type)
    content = getattr(message, "content", "")

    if role == "tool":
        return {
            "role": "tool",
            "tool_call_id": getattr(message, "tool_call_id", ""),
            "content": content,
        }

    if role == "assistant":
        tool_calls = getattr(message, "tool_calls", None) or []
        result: dict = {"role": "assistant"}

        if tool_calls:
            openai_tool_calls = []
            for tc in tool_calls:
                args = tc.get("args", {})
                openai_tool_calls.append(
                    {
                        "id": tc.get("id", ""),
                        "type": "function",
                        "function": {
                            "name": tc.get("name", ""),
                            "arguments": json.dumps(args) if not isinstance(args, str) else args,
                        },
                    }
                )
            # If there is no text content, set content to null per the OpenAI spec
            result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
            result["tool_calls"] = openai_tool_calls
        else:
            result["content"] = content

        return result

    # user / system / unknown
    return {"role": role, "content": content}


def _infer_finish_reason(message: Any) -> str:
    """Infer the OpenAI finish_reason from an AIMessage.

    Returns "tool_calls" if tool_calls are present, else looks in
    response_metadata.finish_reason, else returns "stop".
    """
    tool_calls = getattr(message, "tool_calls", None) or []
    if tool_calls:
        return "tool_calls"
    resp_meta = getattr(message, "response_metadata", None) or {}
    if isinstance(resp_meta, dict):
        finish = resp_meta.get("finish_reason")
        if finish:
            return finish
    return "stop"


def langchain_to_openai_completion(message: Any) -> dict:
    """Convert an AIMessage and its metadata to an OpenAI completion response dict.

    Returns:
        {
            "id": message.id,
            "model": message.response_metadata.get("model_name"),
            "choices": [{"index": 0, "message": <openai_message>, "finish_reason": <inferred>}],
            "usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...} or None,
        }
    """
    resp_meta = getattr(message, "response_metadata", None) or {}
    model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None

    openai_msg = langchain_to_openai_message(message)
    finish_reason = _infer_finish_reason(message)

    usage_metadata = getattr(message, "usage_metadata", None)
    if usage_metadata is not None:
        input_tokens = usage_metadata.get("input_tokens", 0) or 0
        output_tokens = usage_metadata.get("output_tokens", 0) or 0
        usage: dict | None = {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        }
    else:
        usage = None

    return {
        "id": getattr(message, "id", None),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": openai_msg,
                "finish_reason": finish_reason,
            }
        ],
        "usage": usage,
    }


def langchain_messages_to_openai(messages: list) -> list[dict]:
    """Convert a list of LangChain BaseMessages to OpenAI message dicts."""
    return [langchain_to_openai_message(m) for m in messages]
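A quick sketch of the tool-call path, assuming langchain_core's AIMessage; the tool name and arguments are invented.

from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[{"name": "web_search", "args": {"query": "deerflow"}, "id": "call-1"}],
)
print(langchain_to_openai_message(msg))
# {'role': 'assistant', 'content': None,
#  'tool_calls': [{'id': 'call-1', 'type': 'function',
#                  'function': {'name': 'web_search', 'arguments': '{"query": "deerflow"}'}}]}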
@@ -0,0 +1,4 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore"]
|
||||
@@ -0,0 +1,26 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore


def make_run_event_store(config=None) -> RunEventStore:
    """Create a RunEventStore based on the run_events.backend configuration."""
    if config is None or config.backend == "memory":
        return MemoryRunEventStore()
    if config.backend == "db":
        from deerflow.persistence.engine import get_session_factory

        sf = get_session_factory()
        if sf is None:
            # database.backend=memory but run_events.backend=db -> fall back
            return MemoryRunEventStore()
        from deerflow.runtime.events.store.db import DbRunEventStore

        return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
    if config.backend == "jsonl":
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        return JsonlRunEventStore()
    raise ValueError(f"Unknown run_events backend: {config.backend!r}")


__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
@@ -0,0 +1,99 @@
|
||||
"""Abstract interface for run event storage.
|
||||
|
||||
RunEventStore is the unified storage interface for run event streams.
|
||||
Messages (frontend display) and execution traces (debugging/audit) go
|
||||
through the same interface, distinguished by the ``category`` field.
|
||||
|
||||
Implementations:
|
||||
- MemoryRunEventStore: in-memory dict (development, tests)
|
||||
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
|
||||
"""
|

from __future__ import annotations

import abc


class RunEventStore(abc.ABC):
    """Run event stream storage interface.

    All implementations must guarantee:
    1. put() events are retrievable in subsequent queries
    2. seq is strictly increasing within the same thread
    3. list_messages() only returns category="message" events
    4. list_events() returns all events for the specified run
    5. Returned dicts match the RunEvent field structure
    """

    @abc.abstractmethod
    async def put(
        self,
        *,
        thread_id: str,
        run_id: str,
        event_type: str,
        category: str,
        content: str | dict = "",
        metadata: dict | None = None,
        created_at: str | None = None,
    ) -> dict:
        """Write an event, auto-assign seq, return the complete record."""

    @abc.abstractmethod
    async def put_batch(self, events: list[dict]) -> list[dict]:
        """Batch-write events. Used by RunJournal flush buffer.

        Each dict's keys match put()'s keyword arguments.
        Returns complete records with seq assigned.
        """

    @abc.abstractmethod
    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[dict]:
        """Return displayable messages (category=message) for a thread, ordered by seq ascending.

        Supports bidirectional cursor pagination:
        - before_seq: return the last ``limit`` records with seq < before_seq (ascending)
        - after_seq: return the first ``limit`` records with seq > after_seq (ascending)
        - neither: return the latest ``limit`` records (ascending)
        """

    @abc.abstractmethod
    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
    ) -> list[dict]:
        """Return the full event stream for a run, ordered by seq ascending.

        Optionally filter by event_types.
        """

    @abc.abstractmethod
    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
    ) -> list[dict]:
        """Return displayable messages (category=message) for a specific run, ordered by seq ascending."""

    @abc.abstractmethod
    async def count_messages(self, thread_id: str) -> int:
        """Count displayable messages (category=message) in a thread."""

    @abc.abstractmethod
    async def delete_by_thread(self, thread_id: str) -> int:
        """Delete all events for a thread. Return the number of deleted events."""

    @abc.abstractmethod
    async def delete_by_run(self, thread_id: str, run_id: str) -> int:
        """Delete all events for a specific run. Return the number of deleted events."""
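A minimal sketch of the cursor contract, exercised against the in-memory implementation from this changeset (the asyncio wrapper is illustrative):

# Illustrative only; demonstrates guarantee 2 plus bidirectional cursors.
import asyncio

async def demo() -> None:
    store = MemoryRunEventStore()
    for i in range(5):  # seq values 1..5, strictly increasing per thread
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message",
                        category="message", content=f"msg {i}")
    latest = await store.list_messages("t1", limit=2)                       # seqs [4, 5]
    older = await store.list_messages("t1", before_seq=latest[0]["seq"], limit=2)
    assert [e["seq"] for e in older] == [2, 3]                              # ascending

asyncio.run(demo())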
@@ -0,0 +1,261 @@
"""SQLAlchemy-backed RunEventStore implementation.

Persists events to the ``run_events`` table. Trace content is truncated
at ``max_trace_content`` bytes to avoid bloating the database.
"""

from __future__ import annotations

import json
import logging
from datetime import UTC, datetime

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id

logger = logging.getLogger(__name__)


class DbRunEventStore(RunEventStore):
    def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
        self._sf = session_factory
        self._max_trace_content = max_trace_content

    @staticmethod
    def _row_to_dict(row: RunEventRow) -> dict:
        d = row.to_dict()
        d["metadata"] = d.pop("event_metadata", {})
        val = d.get("created_at")
        if isinstance(val, datetime):
            d["created_at"] = val.isoformat()
        d.pop("id", None)
        # Restore dict content that was JSON-serialized on write
        raw = d.get("content", "")
        if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
            try:
                d["content"] = json.loads(raw)
            except (json.JSONDecodeError, ValueError):
                # Content looked like JSON (content_is_dict flag) but failed to parse;
                # keep the raw string as-is.
                logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
        return d

    def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
        if category == "trace":
            text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
            encoded = text.encode("utf-8")
            if len(encoded) > self._max_trace_content:
                # Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
                content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
                metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
        return content, metadata or {}

    @staticmethod
    def _owner_from_context() -> str | None:
        """Soft read of owner_id from contextvar for write paths.

        Returns ``None`` (no filter / no stamp) if contextvar is unset,
        which is the expected case for background worker writes. HTTP
        request writes will have the contextvar set by auth middleware
        and get their user_id stamped automatically.
        """
        user = get_current_user()
        return user.id if user is not None else None

    async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):  # noqa: D401
        """Write a single event — low-frequency path only.

        This opens a dedicated transaction with a FOR UPDATE lock to
        assign a monotonic *seq*. For high-throughput writes use
        :meth:`put_batch`, which acquires the lock once for the whole
        batch. Currently the only caller is ``worker.run_agent`` for
        the initial ``human_message`` event (once per run).
        """
        content, metadata = self._truncate_trace(category, content, metadata)
        if isinstance(content, dict):
            db_content = json.dumps(content, default=str, ensure_ascii=False)
            metadata = {**(metadata or {}), "content_is_dict": True}
        else:
            db_content = content
        owner_id = self._owner_from_context()
        async with self._sf() as session:
            async with session.begin():
                # Use FOR UPDATE to serialize seq assignment within a thread.
                # NOTE: with_for_update() on aggregates is a no-op on SQLite;
                # the UNIQUE(thread_id, seq) constraint catches races there.
                max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
                seq = (max_seq or 0) + 1
                row = RunEventRow(
                    thread_id=thread_id,
                    run_id=run_id,
                    owner_id=owner_id,
                    event_type=event_type,
                    category=category,
                    content=db_content,
                    event_metadata=metadata,
                    seq=seq,
                    created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
                )
                session.add(row)
            return self._row_to_dict(row)

    async def put_batch(self, events):
        if not events:
            return []
        owner_id = self._owner_from_context()
        async with self._sf() as session:
            async with session.begin():
                # Get max seq for the thread (assume all events in batch belong to same thread).
                # NOTE: with_for_update() on aggregates is a no-op on SQLite;
                # the UNIQUE(thread_id, seq) constraint catches races there.
                thread_id = events[0]["thread_id"]
                max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
                seq = max_seq or 0
                rows = []
                for e in events:
                    seq += 1
                    content = e.get("content", "")
                    category = e.get("category", "trace")
                    metadata = e.get("metadata")
                    content, metadata = self._truncate_trace(category, content, metadata)
                    if isinstance(content, dict):
                        db_content = json.dumps(content, default=str, ensure_ascii=False)
                        metadata = {**(metadata or {}), "content_is_dict": True}
                    else:
                        db_content = content
                    row = RunEventRow(
                        thread_id=e["thread_id"],
                        run_id=e["run_id"],
                        owner_id=e.get("owner_id", owner_id),
                        event_type=e["event_type"],
                        category=category,
                        content=db_content,
                        event_metadata=metadata,
                        seq=seq,
                        created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
                    )
                    session.add(row)
                    rows.append(row)
            return [self._row_to_dict(r) for r in rows]

    async def list_messages(
        self,
        thread_id,
        *,
        limit=50,
        before_seq=None,
        after_seq=None,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
        if resolved_owner_id is not None:
            stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
        if before_seq is not None:
            stmt = stmt.where(RunEventRow.seq < before_seq)
        if after_seq is not None:
            stmt = stmt.where(RunEventRow.seq > after_seq)

        if after_seq is not None:
            # Forward pagination: first `limit` records after cursor
            stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
            async with self._sf() as session:
                result = await session.execute(stmt)
                return [self._row_to_dict(r) for r in result.scalars()]
        else:
            # before_seq or default (latest): take last `limit` records, return ascending
            stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
            async with self._sf() as session:
                result = await session.execute(stmt)
                rows = list(result.scalars())
                return [self._row_to_dict(r) for r in reversed(rows)]

    async def list_events(
        self,
        thread_id,
        run_id,
        *,
        event_types=None,
        limit=500,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
        if resolved_owner_id is not None:
            stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
        if event_types:
            stmt = stmt.where(RunEventRow.event_type.in_(event_types))
        stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def list_messages_by_run(
        self,
        thread_id,
        run_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
        if resolved_owner_id is not None:
            stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
        stmt = stmt.order_by(RunEventRow.seq.asc())
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def count_messages(
        self,
        thread_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
        stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
        if resolved_owner_id is not None:
            stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
        async with self._sf() as session:
            return await session.scalar(stmt) or 0

    async def delete_by_thread(
        self,
        thread_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
        async with self._sf() as session:
            count_conditions = [RunEventRow.thread_id == thread_id]
            if resolved_owner_id is not None:
                count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
            count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
            count = await session.scalar(count_stmt) or 0
            if count > 0:
                await session.execute(delete(RunEventRow).where(*count_conditions))
                await session.commit()
            return count

    async def delete_by_run(
        self,
        thread_id,
        run_id,
        *,
        owner_id: str | None | _AutoSentinel = AUTO,
    ):
        resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
        async with self._sf() as session:
            count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
            if resolved_owner_id is not None:
                count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
            count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
            count = await session.scalar(count_stmt) or 0
            if count > 0:
                await session.execute(delete(RunEventRow).where(*count_conditions))
                await session.commit()
            return count
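The byte-level truncation in ``_truncate_trace`` is UTF-8-safe; a small sketch of the observable behavior (the tiny cap and the ``session_factory`` name are illustrative):

# Illustrative: a 2-byte cap splits the 2-byte "é" in "héllo" (6 bytes total);
# errors="ignore" drops the dangling half-character instead of raising.
store = DbRunEventStore(session_factory, max_trace_content=2)  # session_factory assumed; unused by this helper
content, meta = store._truncate_trace("trace", "héllo", None)
assert content == "h"
assert meta == {"content_truncated": True, "original_byte_length": 6}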
@@ -0,0 +1,179 @@
"""JSONL file-backed RunEventStore implementation.

Each run's events are stored in a single file:
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``

All categories (message, trace, lifecycle) are in the same file.
This backend is suitable for lightweight single-node deployments.

Known trade-off: ``list_messages()`` must scan all run files for a
thread since messages from multiple runs need unified seq ordering.
``list_events()`` reads only one file -- the fast path.
"""

from __future__ import annotations

import json
import logging
import re
from datetime import UTC, datetime
from pathlib import Path

from deerflow.runtime.events.store.base import RunEventStore

logger = logging.getLogger(__name__)

_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")


class JsonlRunEventStore(RunEventStore):
    def __init__(self, base_dir: str | Path | None = None):
        self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
        self._seq_counters: dict[str, int] = {}  # thread_id -> current max seq

    @staticmethod
    def _validate_id(value: str, label: str) -> str:
        """Validate that an ID is safe for use in filesystem paths."""
        if not value or not _SAFE_ID_PATTERN.match(value):
            raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
        return value

    def _thread_dir(self, thread_id: str) -> Path:
        self._validate_id(thread_id, "thread_id")
        return self._base_dir / "threads" / thread_id / "runs"

    def _run_file(self, thread_id: str, run_id: str) -> Path:
        self._validate_id(run_id, "run_id")
        return self._thread_dir(thread_id) / f"{run_id}.jsonl"

    def _next_seq(self, thread_id: str) -> int:
        self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
        return self._seq_counters[thread_id]

    def _ensure_seq_loaded(self, thread_id: str) -> None:
        """Load max seq from existing files if not yet cached."""
        if thread_id in self._seq_counters:
            return
        max_seq = 0
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for f in thread_dir.glob("*.jsonl"):
                for line in f.read_text(encoding="utf-8").strip().splitlines():
                    try:
                        record = json.loads(line)
                        max_seq = max(max_seq, record.get("seq", 0))
                    except json.JSONDecodeError:
                        logger.debug("Skipping malformed JSONL line in %s", f)
                        continue
        self._seq_counters[thread_id] = max_seq

    def _write_record(self, record: dict) -> None:
        path = self._run_file(record["thread_id"], record["run_id"])
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")

    def _read_thread_events(self, thread_id: str) -> list[dict]:
        """Read all events for a thread, sorted by seq."""
        events = []
        thread_dir = self._thread_dir(thread_id)
        if not thread_dir.exists():
            return events
        for f in sorted(thread_dir.glob("*.jsonl")):
            for line in f.read_text(encoding="utf-8").strip().splitlines():
                if not line:
                    continue
                try:
                    events.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.debug("Skipping malformed JSONL line in %s", f)
                    continue
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
        """Read events for a specific run file."""
        path = self._run_file(thread_id, run_id)
        if not path.exists():
            return []
        events = []
        for line in path.read_text(encoding="utf-8").strip().splitlines():
            if not line:
                continue
            try:
                events.append(json.loads(line))
            except json.JSONDecodeError:
                logger.debug("Skipping malformed JSONL line in %s", path)
                continue
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
        self._ensure_seq_loaded(thread_id)
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        self._write_record(record)
        return record

    async def put_batch(self, events):
        if not events:
            return []
        results = []
        for ev in events:
            record = await self.put(**ev)
            results.append(record)
        return results

    async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
        all_events = self._read_thread_events(thread_id)
        messages = [e for e in all_events if e.get("category") == "message"]

        if before_seq is not None:
            messages = [e for e in messages if e["seq"] < before_seq]
            return messages[-limit:]
        elif after_seq is not None:
            messages = [e for e in messages if e["seq"] > after_seq]
            return messages[:limit]
        else:
            return messages[-limit:]

    async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
        events = self._read_run_events(thread_id, run_id)
        if event_types is not None:
            events = [e for e in events if e.get("event_type") in event_types]
        return events[:limit]

    async def list_messages_by_run(self, thread_id, run_id):
        events = self._read_run_events(thread_id, run_id)
        return [e for e in events if e.get("category") == "message"]

    async def count_messages(self, thread_id):
        all_events = self._read_thread_events(thread_id)
        return sum(1 for e in all_events if e.get("category") == "message")

    async def delete_by_thread(self, thread_id):
        all_events = self._read_thread_events(thread_id)
        count = len(all_events)
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for f in thread_dir.glob("*.jsonl"):
                f.unlink()
        self._seq_counters.pop(thread_id, None)
        return count

    async def delete_by_run(self, thread_id, run_id):
        events = self._read_run_events(thread_id, run_id)
        count = len(events)
        path = self._run_file(thread_id, run_id)
        if path.exists():
            path.unlink()
        return count
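For reference, one record line as this backend would append it (field values illustrative):

{"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "hello", "metadata": {}, "seq": 3, "created_at": "2025-01-01T00:00:00+00:00"}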
@@ -0,0 +1,120 @@
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.

Safe for single-process async usage: no threading locks are needed
because all mutations happen within the same event loop.
"""

from __future__ import annotations

from datetime import UTC, datetime

from deerflow.runtime.events.store.base import RunEventStore


class MemoryRunEventStore(RunEventStore):
    def __init__(self) -> None:
        self._events: dict[str, list[dict]] = {}  # thread_id -> sorted event list
        self._seq_counters: dict[str, int] = {}  # thread_id -> last assigned seq

    def _next_seq(self, thread_id: str) -> int:
        current = self._seq_counters.get(thread_id, 0)
        next_val = current + 1
        self._seq_counters[thread_id] = next_val
        return next_val

    def _put_one(
        self,
        *,
        thread_id: str,
        run_id: str,
        event_type: str,
        category: str,
        content: str | dict = "",
        metadata: dict | None = None,
        created_at: str | None = None,
    ) -> dict:
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        self._events.setdefault(thread_id, []).append(record)
        return record

    async def put(
        self,
        *,
        thread_id,
        run_id,
        event_type,
        category,
        content="",
        metadata=None,
        created_at=None,
    ):
        return self._put_one(
            thread_id=thread_id,
            run_id=run_id,
            event_type=event_type,
            category=category,
            content=content,
            metadata=metadata,
            created_at=created_at,
        )

    async def put_batch(self, events):
        results = []
        for ev in events:
            record = self._put_one(**ev)
            results.append(record)
        return results

    async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
        all_events = self._events.get(thread_id, [])
        messages = [e for e in all_events if e["category"] == "message"]

        if before_seq is not None:
            messages = [e for e in messages if e["seq"] < before_seq]
            # Take the last `limit` records
            return messages[-limit:]
        elif after_seq is not None:
            messages = [e for e in messages if e["seq"] > after_seq]
            return messages[:limit]
        else:
            # Return the latest `limit` records, ascending
            return messages[-limit:]

    async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
        all_events = self._events.get(thread_id, [])
        filtered = [e for e in all_events if e["run_id"] == run_id]
        if event_types is not None:
            filtered = [e for e in filtered if e["event_type"] in event_types]
        return filtered[:limit]

    async def list_messages_by_run(self, thread_id, run_id):
        all_events = self._events.get(thread_id, [])
        return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]

    async def count_messages(self, thread_id):
        all_events = self._events.get(thread_id, [])
        return sum(1 for e in all_events if e["category"] == "message")

    async def delete_by_thread(self, thread_id):
        events = self._events.pop(thread_id, [])
        self._seq_counters.pop(thread_id, None)
        return len(events)

    async def delete_by_run(self, thread_id, run_id):
        all_events = self._events.get(thread_id, [])
        if not all_events:
            return 0
        remaining = [e for e in all_events if e["run_id"] != run_id]
        removed = len(all_events) - len(remaining)
        self._events[thread_id] = remaining
        return removed
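One design point worth calling out: ``delete_by_run`` leaves the thread's seq counter untouched, so later events keep strictly increasing seq values; only ``delete_by_thread`` resets it. A quick sketch (IDs illustrative):

# Illustrative: seq keeps climbing after a per-run delete.
import asyncio

async def demo() -> None:
    store = MemoryRunEventStore()
    await store.put(thread_id="t", run_id="r1", event_type="x", category="trace")  # seq 1
    await store.delete_by_run("t", "r1")
    rec = await store.put(thread_id="t", run_id="r2", event_type="x", category="trace")
    assert rec["seq"] == 2  # counter survives the per-run delete

asyncio.run(demo())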
@@ -0,0 +1,471 @@
"""Run event capture via LangChain callbacks.

RunJournal sits between LangChain's callback mechanism and the pluggable
RunEventStore. It standardizes callback data into RunEvent records and
handles token usage accumulation.

Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
- on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
"""
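To make the llm_request/llm_response pairing concrete, two illustrative records (field values hypothetical; the llm_response content follows whatever ``langchain_to_openai_completion`` emits):

# llm_request, emitted by on_chat_model_start (category="trace"):
{"event_type": "llm_request", "category": "trace",
 "content": {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
 "metadata": {"caller": "lead_agent", "llm_call_index": 1}}

# llm_response, emitted by on_llm_end (category="trace"):
{"event_type": "llm_response", "category": "trace",
 "content": "<OpenAI Chat Completions dict from langchain_to_openai_completion>",
 "metadata": {"caller": "lead_agent",
              "usage": {"input_tokens": 8, "output_tokens": 4, "total_tokens": 12},
              "latency_ms": 420, "llm_call_index": 1}}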

from __future__ import annotations

import asyncio
import logging
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler

if TYPE_CHECKING:
    from deerflow.runtime.events.store.base import RunEventStore

logger = logging.getLogger(__name__)


class RunJournal(BaseCallbackHandler):
    """LangChain callback handler that captures events to RunEventStore."""

    def __init__(
        self,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        *,
        track_token_usage: bool = True,
        flush_threshold: int = 20,
    ):
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._track_tokens = track_token_usage
        self._flush_threshold = flush_threshold

        # Write buffer
        self._buffer: list[dict] = []

        # Token accumulators
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0

        # Convenience fields
        self._last_ai_msg: str | None = None
        self._first_human_msg: str | None = None
        self._msg_count = 0

        # Latency tracking
        self._llm_start_times: dict[str, float] = {}  # langchain run_id -> start time

        # LLM request/response tracking
        self._llm_call_index = 0
        self._cached_prompts: dict[str, list[dict]] = {}  # langchain run_id -> OpenAI messages
        self._cached_models: dict[str, str] = {}  # langchain run_id -> model name

        # Tool call ID cache
        self._tool_call_ids: dict[str, str] = {}  # langchain run_id -> tool_call_id

    # -- Lifecycle callbacks --

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_start",
            category="lifecycle",
            metadata={"input_preview": str(inputs)[:500]},
        )

    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_error",
            category="lifecycle",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    # -- LLM callbacks --

    def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
        """Capture structured prompt messages for llm_request event."""
        from deerflow.runtime.converters import langchain_messages_to_openai

        rid = str(run_id)
        self._llm_start_times[rid] = time.monotonic()
        self._llm_call_index += 1

        model_name = serialized.get("name", "")
        self._cached_models[rid] = model_name

        # Convert the first message list (LangChain passes list-of-lists)
        prompt_msgs = messages[0] if messages else []
        openai_msgs = langchain_messages_to_openai(prompt_msgs)
        self._cached_prompts[rid] = openai_msgs

        caller = self._identify_caller(kwargs)
        self._put(
            event_type="llm_request",
            category="trace",
            content={"model": model_name, "messages": openai_msgs},
            metadata={"caller": caller, "llm_call_index": self._llm_call_index},
        )

    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
        # Fallback: on_chat_model_start is preferred. This just tracks latency.
        self._llm_start_times[str(run_id)] = time.monotonic()

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        from deerflow.runtime.converters import langchain_to_openai_completion

        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            logger.debug("on_llm_end: could not extract message from response")
            return

        caller = self._identify_caller(kwargs)

        # Latency
        rid = str(run_id)
        start = self._llm_start_times.pop(rid, None)
        latency_ms = int((time.monotonic() - start) * 1000) if start else None

        # Token usage from message
        usage = getattr(message, "usage_metadata", None)
        usage_dict = dict(usage) if usage else {}

        # Resolve call index
        call_index = self._llm_call_index
        if rid not in self._cached_prompts:
            # Fallback: on_chat_model_start was not called
            self._llm_call_index += 1
            call_index = self._llm_call_index

        # Clean up caches
        self._cached_prompts.pop(rid, None)
        self._cached_models.pop(rid, None)

        # Trace event: llm_response (OpenAI completion format)
        content = getattr(message, "content", "")
        self._put(
            event_type="llm_response",
            category="trace",
            content=langchain_to_openai_completion(message),
            metadata={
                "caller": caller,
                "usage": usage_dict,
                "latency_ms": latency_ms,
                "llm_call_index": call_index,
            },
        )

        # Message events: only lead_agent gets message-category events.
        # Content uses message.model_dump() to align with checkpoint format.
        tool_calls = getattr(message, "tool_calls", None) or []
        if caller == "lead_agent":
            resp_meta = getattr(message, "response_metadata", None) or {}
            model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
            if tool_calls:
                # ai_tool_call: agent decided to use tools
                self._put(
                    event_type="ai_tool_call",
                    category="message",
                    content=message.model_dump(),
                    metadata={"model_name": model_name, "finish_reason": "tool_calls"},
                )
            elif isinstance(content, str) and content:
                # ai_message: final text reply
                self._put(
                    event_type="ai_message",
                    category="message",
                    content=message.model_dump(),
                    metadata={"model_name": model_name, "finish_reason": "stop"},
                )
                self._last_ai_msg = content
                self._msg_count += 1

        # Token accumulation
        if self._track_tokens:
            input_tk = usage_dict.get("input_tokens", 0) or 0
            output_tk = usage_dict.get("output_tokens", 0) or 0
            total_tk = usage_dict.get("total_tokens", 0) or 0
            if total_tk == 0:
                total_tk = input_tk + output_tk
            if total_tk > 0:
                self._total_input_tokens += input_tk
                self._total_output_tokens += output_tk
                self._total_tokens += total_tk
                self._llm_call_count += 1
                if caller.startswith("subagent:"):
                    self._subagent_tokens += total_tk
                elif caller.startswith("middleware:"):
                    self._middleware_tokens += total_tk
                else:
                    self._lead_agent_tokens += total_tk

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        self._llm_start_times.pop(str(run_id), None)
        self._put(event_type="llm_error", category="trace", content=str(error))

    # -- Tool callbacks --

    def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        tool_call_id = kwargs.get("tool_call_id")
        if tool_call_id:
            self._tool_call_ids[str(run_id)] = tool_call_id
        self._put(
            event_type="tool_start",
            category="trace",
            metadata={
                "tool_name": serialized.get("name", ""),
                "tool_call_id": tool_call_id,
                "args": str(input_str)[:2000],
            },
        )

    def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
        from langchain_core.messages import ToolMessage

        # Extract fields from ToolMessage object when LangChain provides one.
        # LangChain's _format_output wraps tool results into a ToolMessage
        # with tool_call_id, name, status, and artifact — more complete than
        # what kwargs alone provides.
        if isinstance(output, ToolMessage):
            tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
            tool_name = output.name or kwargs.get("name", "")
            status = getattr(output, "status", "success") or "success"
            content_str = output.content if isinstance(output.content, str) else str(output.content)
            # Use model_dump() for checkpoint-aligned message content.
            # Override tool_call_id if it was resolved from cache.
            msg_content = output.model_dump()
            if msg_content.get("tool_call_id") != tool_call_id:
                msg_content["tool_call_id"] = tool_call_id
        else:
            tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
            tool_name = kwargs.get("name", "")
            status = "success"
            content_str = str(output)
            # Construct checkpoint-aligned dict when output is a plain string.
            msg_content = ToolMessage(
                content=content_str,
                tool_call_id=tool_call_id or "",
                name=tool_name,
                status=status,
            ).model_dump()

        # Trace event (always)
        self._put(
            event_type="tool_end",
            category="trace",
            content=content_str,
            metadata={
                "tool_name": tool_name,
                "tool_call_id": tool_call_id,
                "status": status,
            },
        )

        # Message event: tool_result (checkpoint-aligned model_dump format)
        self._put(
            event_type="tool_result",
            category="message",
            content=msg_content,
            metadata={"tool_name": tool_name, "status": status},
        )

    def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        from langchain_core.messages import ToolMessage

        tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
        tool_name = kwargs.get("name", "")

        # Trace event
        self._put(
            event_type="tool_error",
            category="trace",
            content=str(error),
            metadata={
                "tool_name": tool_name,
                "tool_call_id": tool_call_id,
            },
        )

        # Message event: tool_result with error status (checkpoint-aligned)
        msg_content = ToolMessage(
            content=str(error),
            tool_call_id=tool_call_id or "",
            name=tool_name,
            status="error",
        ).model_dump()
        self._put(
            event_type="tool_result",
            category="message",
            content=msg_content,
            metadata={"tool_name": tool_name, "status": "error"},
        )

    # -- Custom event callback --

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        from deerflow.runtime.serialization import serialize_lc_object

        if name == "summarization":
            data_dict = data if isinstance(data, dict) else {}
            self._put(
                event_type="summarization",
                category="trace",
                content=data_dict.get("summary", ""),
                metadata={
                    "replaced_message_ids": data_dict.get("replaced_message_ids", []),
                    "replaced_count": data_dict.get("replaced_count", 0),
                },
            )
            self._put(
                event_type="middleware:summarize",
                category="middleware",
                content={"role": "system", "content": data_dict.get("summary", "")},
                metadata={"replaced_count": data_dict.get("replaced_count", 0)},
            )
        else:
            event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
            self._put(
                event_type=name,
                category="trace",
                metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
            )

    # -- Internal methods --

    def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
        self._buffer.append(
            {
                "thread_id": self.thread_id,
                "run_id": self.run_id,
                "event_type": event_type,
                "category": category,
                "content": content,
                "metadata": metadata or {},
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _flush_sync(self) -> None:
        """Best-effort flush of buffer to RunEventStore.

        BaseCallbackHandler methods are synchronous. If an event loop is
        running we schedule an async ``put_batch``; otherwise the events
        stay in the buffer and are flushed later by the async ``flush()``
        call in the worker's ``finally`` block.
        """
        if not self._buffer:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop — keep events in buffer for later async flush.
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        task = loop.create_task(self._flush_async(batch))
        task.add_done_callback(self._on_flush_done)

    async def _flush_async(self, batch: list[dict]) -> None:
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning(
                "Failed to flush %d events for run %s — returning to buffer",
                len(batch),
                self.run_id,
                exc_info=True,
            )
            # Return failed events to buffer for retry on next flush
            self._buffer = batch + self._buffer

    @staticmethod
    def _on_flush_done(task: asyncio.Task) -> None:
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.warning("Journal flush task failed: %s", exc)

    def _identify_caller(self, kwargs: dict) -> str:
        for tag in kwargs.get("tags") or []:
            if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
                return tag
        # Default to lead_agent: the main agent graph does not inject
        # callback tags, while subagents and middleware explicitly tag
        # themselves.
        return "lead_agent"

    # -- Public methods (called by worker) --

    def set_first_human_message(self, content: str) -> None:
        """Record the first human message for convenience fields."""
        self._first_human_msg = content[:2000] if content else None

    def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
        """Record a middleware state-change event.

        Called by middleware implementations when they perform a meaningful
        state change (e.g., title generation, summarization, HITL approval).
        Pure-observation middleware should not call this.

        Args:
            tag: Short identifier for the middleware (e.g., "title", "summarize",
                "guardrail"). Used to form event_type="middleware:{tag}".
            name: Full middleware class name.
            hook: Lifecycle hook that triggered the action (e.g., "after_model").
            action: Specific action performed (e.g., "generate_title").
            changes: Dict describing the state changes made.
        """
        self._put(
            event_type=f"middleware:{tag}",
            category="middleware",
            content={"name": name, "hook": hook, "action": action, "changes": changes},
        )

    async def flush(self) -> None:
        """Force flush remaining buffer. Called in worker's finally block."""
        if self._buffer:
            batch = self._buffer.copy()
            self._buffer.clear()
            await self._store.put_batch(batch)

    def get_completion_data(self) -> dict:
        """Return accumulated token and message data for run completion."""
        return {
            "total_input_tokens": self._total_input_tokens,
            "total_output_tokens": self._total_output_tokens,
            "total_tokens": self._total_tokens,
            "llm_call_count": self._llm_call_count,
            "lead_agent_tokens": self._lead_agent_tokens,
            "subagent_tokens": self._subagent_tokens,
            "middleware_tokens": self._middleware_tokens,
            "message_count": self._msg_count,
            "last_ai_message": self._last_ai_msg,
            "first_human_message": self._first_human_msg,
        }
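The caller attribution above depends on tags being injected at the call site; a hedged sketch of what a subagent's model invocation might look like (only the tag strings are the contract, the surrounding call is illustrative):

# Illustrative, inside an async context; `model` is any LangChain chat model.
# Tags flow through LangChain callbacks into _identify_caller(), so RunJournal
# attributes this call's tokens to self._subagent_tokens.
result = await model.ainvoke(messages, config={"tags": ["subagent:researcher"]})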
@@ -2,11 +2,12 @@

from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import run_agent
from .worker import RunContext, run_agent

__all__ = [
    "ConflictError",
    "DisconnectMode",
    "RunContext",
    "RunManager",
    "RunRecord",
    "RunStatus",

@@ -1,4 +1,4 @@
"""In-memory run registry."""
"""In-memory run registry with optional persistent RunStore backing."""

from __future__ import annotations

@@ -7,9 +7,13 @@ import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING

from .schemas import DisconnectMode, RunStatus

if TYPE_CHECKING:
    from deerflow.runtime.runs.store.base import RunStore

logger = logging.getLogger(__name__)


@@ -38,11 +42,44 @@ class RunRecord:


class RunManager:
    """In-memory run registry. All mutations are protected by an asyncio lock."""
    """In-memory run registry with optional persistent RunStore backing.

    def __init__(self) -> None:
    All mutations are protected by an asyncio lock. When a ``store`` is
    provided, serializable metadata is also persisted to the store so
    that run history survives process restarts.
    """

    def __init__(self, store: RunStore | None = None) -> None:
        self._runs: dict[str, RunRecord] = {}
        self._lock = asyncio.Lock()
        self._store = store

    async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
        """Best-effort persist run record to backing store."""
        if self._store is None:
            return
        try:
            await self._store.put(
                record.run_id,
                thread_id=record.thread_id,
                assistant_id=record.assistant_id,
                status=record.status.value,
                multitask_strategy=record.multitask_strategy,
                metadata=record.metadata or {},
                kwargs=record.kwargs or {},
                created_at=record.created_at,
                follow_up_to_run_id=follow_up_to_run_id,
            )
        except Exception:
            logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)

    async def update_run_completion(self, run_id: str, **kwargs) -> None:
        """Persist token usage and completion data to the backing store."""
        if self._store is not None:
            try:
                await self._store.update_run_completion(run_id, **kwargs)
            except Exception:
                logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)

    async def create(
        self,
@@ -53,6 +90,7 @@ class RunManager:
        metadata: dict | None = None,
        kwargs: dict | None = None,
        multitask_strategy: str = "reject",
        follow_up_to_run_id: str | None = None,
    ) -> RunRecord:
        """Create a new pending run and register it."""
        run_id = str(uuid.uuid4())
@@ -71,6 +109,7 @@ class RunManager:
        )
        async with self._lock:
            self._runs[run_id] = record
        await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
        logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
        return record

@@ -96,6 +135,11 @@ class RunManager:
            record.updated_at = _now_iso()
            if error is not None:
                record.error = error
        if self._store is not None:
            try:
                await self._store.update_status(run_id, status.value, error=error)
            except Exception:
                logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
        logger.info("Run %s -> %s", run_id, status.value)

    async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
@@ -132,6 +176,7 @@ class RunManager:
        metadata: dict | None = None,
        kwargs: dict | None = None,
        multitask_strategy: str = "reject",
        follow_up_to_run_id: str | None = None,
    ) -> RunRecord:
        """Atomically check for inflight runs and create a new one.

@@ -185,6 +230,7 @@ class RunManager:
            )
            self._runs[run_id] = record

        await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
        logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
        return record


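A quick wiring sketch for the new constructor (MemoryRunStore stands in for any RunStore; create() arguments beyond the ones visible in these hunks are elided):

# Illustrative: with a store attached, run metadata is also persisted.
# A store failure is logged but never fails the in-memory registration.
manager = RunManager(store=MemoryRunStore())
# manager.create(...) and create_if_idle-style paths now call
# _persist_to_store() best-effort after registering the RunRecord.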
@@ -0,0 +1,4 @@
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.runs.store.memory import MemoryRunStore

__all__ = ["MemoryRunStore", "RunStore"]
@@ -0,0 +1,96 @@
"""Abstract interface for run metadata storage.

RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM

All methods accept an optional owner_id for user isolation.
When owner_id is None, no user filtering is applied (single-user mode).
"""

from __future__ import annotations

import abc
from typing import Any


class RunStore(abc.ABC):
    @abc.abstractmethod
    async def put(
        self,
        run_id: str,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        owner_id: str | None = None,
        status: str = "pending",
        multitask_strategy: str = "reject",
        metadata: dict[str, Any] | None = None,
        kwargs: dict[str, Any] | None = None,
        error: str | None = None,
        created_at: str | None = None,
        follow_up_to_run_id: str | None = None,
    ) -> None:
        pass

    @abc.abstractmethod
    async def get(self, run_id: str) -> dict[str, Any] | None:
        pass

    @abc.abstractmethod
    async def list_by_thread(
        self,
        thread_id: str,
        *,
        owner_id: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        pass

    @abc.abstractmethod
    async def update_status(
        self,
        run_id: str,
        status: str,
        *,
        error: str | None = None,
    ) -> None:
        pass

    @abc.abstractmethod
    async def delete(self, run_id: str) -> None:
        pass

    @abc.abstractmethod
    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        last_ai_message: str | None = None,
        first_human_message: str | None = None,
        error: str | None = None,
    ) -> None:
        pass

    @abc.abstractmethod
    async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
        pass

    @abc.abstractmethod
    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Aggregate token usage for completed runs in a thread.

        Returns a dict with keys: total_tokens, total_input_tokens,
        total_output_tokens, total_runs, by_model (model_name → {tokens, runs}),
        by_caller ({lead_agent, subagent, middleware}).
        """
        pass
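A concrete example of the ``aggregate_tokens_by_thread`` return shape described above (numbers illustrative):

{
    "total_tokens": 5400,
    "total_input_tokens": 4100,
    "total_output_tokens": 1300,
    "total_runs": 3,
    "by_model": {"gpt-4o": {"tokens": 5400, "runs": 3}},
    "by_caller": {"lead_agent": 3600, "subagent": 1500, "middleware": 300},
}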
@@ -0,0 +1,100 @@
"""In-memory RunStore. Used when database.backend=memory (default) and in tests.

Equivalent to the original RunManager._runs dict behavior.
"""

from __future__ import annotations

from datetime import UTC, datetime
from typing import Any

from deerflow.runtime.runs.store.base import RunStore


class MemoryRunStore(RunStore):
    def __init__(self) -> None:
        self._runs: dict[str, dict[str, Any]] = {}

    async def put(
        self,
        run_id,
        *,
        thread_id,
        assistant_id=None,
        owner_id=None,
        status="pending",
        multitask_strategy="reject",
        metadata=None,
        kwargs=None,
        error=None,
        created_at=None,
        follow_up_to_run_id=None,
    ):
        now = datetime.now(UTC).isoformat()
        self._runs[run_id] = {
            "run_id": run_id,
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "owner_id": owner_id,
            "status": status,
            "multitask_strategy": multitask_strategy,
            "metadata": metadata or {},
            "kwargs": kwargs or {},
            "error": error,
            "follow_up_to_run_id": follow_up_to_run_id,
            "created_at": created_at or now,
            "updated_at": now,
        }

    async def get(self, run_id):
        return self._runs.get(run_id)

    async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
        results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
        results.sort(key=lambda r: r["created_at"], reverse=True)
        return results[:limit]

    async def update_status(self, run_id, status, *, error=None):
        if run_id in self._runs:
            self._runs[run_id]["status"] = status
            if error is not None:
                self._runs[run_id]["error"] = error
            self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()

    async def delete(self, run_id):
        self._runs.pop(run_id, None)

    async def update_run_completion(self, run_id, *, status, **kwargs):
        if run_id in self._runs:
            self._runs[run_id]["status"] = status
            for key, value in kwargs.items():
                if value is not None:
                    self._runs[run_id][key] = value
            self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()

    async def list_pending(self, *, before=None):
        now = before or datetime.now(UTC).isoformat()
        results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
        results.sort(key=lambda r: r["created_at"])
        return results

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
        by_model: dict[str, dict] = {}
        for r in completed:
            model = r.get("model_name") or "unknown"
            entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
            entry["tokens"] += r.get("total_tokens", 0)
            entry["runs"] += 1
        return {
            "total_tokens": sum(r.get("total_tokens", 0) for r in completed),
            "total_input_tokens": sum(r.get("total_input_tokens", 0) for r in completed),
            "total_output_tokens": sum(r.get("total_output_tokens", 0) for r in completed),
            "total_runs": len(completed),
            "by_model": by_model,
            "by_caller": {
                "lead_agent": sum(r.get("lead_agent_tokens", 0) for r in completed),
                "subagent": sum(r.get("subagent_tokens", 0) for r in completed),
                "middleware": sum(r.get("middleware_tokens", 0) for r in completed),
            },
        }
@@ -19,7 +19,11 @@ import asyncio
import copy
import inspect
import logging
from typing import Any, Literal
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    from langchain_core.messages import HumanMessage

from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -33,13 +37,29 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}


@dataclass(frozen=True)
class RunContext:
    """Infrastructure dependencies for a single agent run.

    Groups checkpointer, store, and persistence-related singletons so that
    ``run_agent`` (and any future callers) receive one object instead of a
    growing list of keyword arguments.
    """

    checkpointer: Any
    store: Any | None = field(default=None)
    event_store: Any | None = field(default=None)
    run_events_config: Any | None = field(default=None)
    thread_meta_repo: Any | None = field(default=None)
    follow_up_to_run_id: str | None = field(default=None)


async def run_agent(
    bridge: StreamBridge,
    run_manager: RunManager,
    record: RunRecord,
    *,
    checkpointer: Any,
    store: Any | None = None,
    ctx: RunContext,
    agent_factory: Any,
    graph_input: dict,
    config: dict,
@@ -50,6 +70,14 @@ async def run_agent(
) -> None:
    """Execute an agent in the background, publishing events to *bridge*."""

    # Unpack infrastructure dependencies from RunContext.
    checkpointer = ctx.checkpointer
    store = ctx.store
    event_store = ctx.event_store
    run_events_config = ctx.run_events_config
    thread_meta_repo = ctx.thread_meta_repo
    follow_up_to_run_id = ctx.follow_up_to_run_id

    run_id = record.run_id
    thread_id = record.thread_id
    requested_modes: set[str] = set(stream_modes or ["values"])
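On the calling side, the RunContext unpacked above would be built once per run; a hedged sketch (dependency names illustrative):

# Illustrative: groups the persistence singletons instead of a growing kwarg list.
ctx = RunContext(
    checkpointer=checkpointer,
    store=store,
    event_store=event_store,            # enables RunJournal capture when set
    run_events_config=run_events_config,
    thread_meta_repo=thread_meta_repo,
    follow_up_to_run_id=None,
)
await run_agent(bridge, run_manager, record, ctx=ctx, agent_factory=agent_factory,
                graph_input=graph_input, config=config)
# Remaining keyword arguments (stream_modes, etc.) are elided in this hunk.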
@@ -57,6 +85,35 @@ async def run_agent(
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@@ -110,6 +167,11 @@ async def run_agent(
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
|
||||
if journal is not None:
|
||||
config.setdefault("callbacks", []).append(journal)
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
@@ -236,6 +298,37 @@
            )

    finally:
        # Flush any buffered journal events and persist completion data
        if journal is not None:
            try:
                await journal.flush()
            except Exception:
                logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)

            # Persist token usage + convenience fields to RunStore
            completion = journal.get_completion_data()
            await run_manager.update_run_completion(run_id, status=record.status.value, **completion)

        # Sync title from checkpoint to threads_meta.display_name
        if checkpointer is not None:
            try:
                ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
                ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
                if ckpt_tuple is not None:
                    ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
                    title = ckpt.get("channel_values", {}).get("title")
                    if title:
                        await thread_meta_repo.update_display_name(thread_id, title)
            except Exception:
                logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)

        # Update threads_meta status based on run outcome
        try:
            final_status = "idle" if record.status == RunStatus.success else record.status.value
            await thread_meta_repo.update_status(thread_id, final_status)
        except Exception:
            logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)

        await bridge.publish_end(run_id)
        asyncio.create_task(bridge.cleanup(run_id, delay=60))
@@ -355,6 +448,31 @@ def _lg_mode_to_sse_event(mode: str) -> str:
    return mode


def _extract_human_message(graph_input: dict) -> HumanMessage | None:
    """Extract or construct a HumanMessage from graph_input for event recording.

    Returns a LangChain HumanMessage so callers can use .model_dump() to get
    the checkpoint-aligned serialization format.
    """
    from langchain_core.messages import HumanMessage

    messages = graph_input.get("messages")
    if not messages:
        return None
    last = messages[-1] if isinstance(messages, list) else messages
    if isinstance(last, HumanMessage):
        return last
    if isinstance(last, str):
        return HumanMessage(content=last) if last else None
    if hasattr(last, "content"):
        content = last.content
        return HumanMessage(content=content)
    if isinstance(last, dict):
        content = last.get("content", "")
        return HumanMessage(content=content) if content else None
    return None


def _unpack_stream_item(
    item: Any,
    lg_modes: list[str],

@@ -0,0 +1,148 @@
"""Request-scoped user context for owner-based authorization.

This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``owner_id`` boilerplate.

Three-state semantics for the repository ``owner_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):

- ``AUTO`` (module-level sentinel, default): read from contextvar;
  raise :class:`RuntimeError` if unset.
- Explicit ``str``: use the provided value, overriding the contextvar.
- Explicit ``None``: no WHERE clause — used only by migration scripts
  and admin CLIs that intentionally bypass isolation.

Dependency direction
--------------------
``persistence`` (lower layer) reads from this module; ``gateway.auth``
(higher layer) writes to it. ``CurrentUser`` is defined here as a
:class:`typing.Protocol` so that ``persistence`` never needs to import
the concrete ``User`` class from ``gateway.auth.models``. Any object
with an ``.id: str`` attribute structurally satisfies the protocol.

Asyncio semantics
-----------------
``ContextVar`` is task-local under asyncio, not thread-local. Each
FastAPI request runs in its own task, so the context is naturally
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
parent task's context, which is typically the intended behaviour; if
a background task must *not* see the foreground user, wrap it with
``contextvars.copy_context()`` to get a clean copy.
"""

from __future__ import annotations

from contextvars import ContextVar, Token
from typing import Final, Protocol, runtime_checkable


@runtime_checkable
class CurrentUser(Protocol):
    """Structural type for the current authenticated user.

    Any object with an ``.id: str`` attribute satisfies this protocol.
    Concrete implementations live in ``app.gateway.auth.models.User``.
    """

    id: str


_current_user: Final[ContextVar[CurrentUser | None]] = ContextVar("deerflow_current_user", default=None)


def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
    """Set the current user for this async task.

    Returns a reset token that should be passed to
    :func:`reset_current_user` in a ``finally`` block to restore the
    previous context.
    """
    return _current_user.set(user)


def reset_current_user(token: Token[CurrentUser | None]) -> None:
    """Restore the context to the state captured by ``token``."""
    _current_user.reset(token)


def get_current_user() -> CurrentUser | None:
    """Return the current user, or ``None`` if unset.

    Safe to call in any context. Used by code paths that can proceed
    without a user (e.g. migration scripts, public endpoints).
    """
    return _current_user.get()


def require_current_user() -> CurrentUser:
    """Return the current user, or raise :class:`RuntimeError`.

    Used by repository code that must not be called outside a
    request-authenticated context. The error message is phrased so
    that a caller debugging a stack trace can locate the offending
    code path.
    """
    user = _current_user.get()
    if user is None:
        raise RuntimeError("repository accessed without user context")
    return user


# ---------------------------------------------------------------------------
# Sentinel-based owner_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept an ``owner_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_owner_id`.


class _AutoSentinel:
    """Singleton marker meaning 'resolve owner_id from contextvar'."""

    _instance: _AutoSentinel | None = None

    def __new__(cls) -> _AutoSentinel:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self) -> str:
        return "<AUTO>"


AUTO: Final[_AutoSentinel] = _AutoSentinel()


def resolve_owner_id(
    value: str | None | _AutoSentinel,
    *,
    method_name: str = "repository method",
) -> str | None:
    """Resolve the owner_id parameter passed to a repository method.

    Three-state semantics:

    - :data:`AUTO` (default): read from contextvar; raise
      :class:`RuntimeError` if no user is in context. This is the
      common case for request-scoped calls.
    - Explicit ``str``: use the provided id verbatim, overriding any
      contextvar value. Useful for tests and admin-override flows.
    - Explicit ``None``: no filter — the repository should skip the
      owner_id WHERE clause entirely. Reserved for migration scripts
      and CLI tools that intentionally bypass isolation.
    """
    if isinstance(value, _AutoSentinel):
        user = _current_user.get()
        if user is None:
            raise RuntimeError(
                f"{method_name} called with owner_id=AUTO but no user context is set; "
                "pass an explicit owner_id, set the contextvar via auth middleware, "
                "or opt out with owner_id=None for migration/CLI paths."
            )
        # Coerce to ``str`` at the boundary: ``User.id`` is typed as
        # ``UUID`` for the API surface, but the persistence layer
        # stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
        # bind a raw UUID object to a VARCHAR column ("type 'UUID' is
        # not supported"). Honour the documented return type here
        # rather than ripple a type change through every caller.
        return str(user.id)
    return value
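A quick illustration of the three states, runnable against the module above:

```python
from types import SimpleNamespace

from deerflow.runtime.user_context import (
    AUTO,
    reset_current_user,
    resolve_owner_id,
    set_current_user,
)

# Simulate the auth middleware setting the request's user.
token = set_current_user(SimpleNamespace(id="user-1"))
try:
    # AUTO (the default) resolves from the contextvar:
    assert resolve_owner_id(AUTO, method_name="list_threads") == "user-1"

    # An explicit id overrides the contextvar (tests, admin flows):
    assert resolve_owner_id("user-2") == "user-2"

    # Explicit None means "no owner filter" (migration/CLI paths only):
    assert resolve_owner_id(None) is None
finally:
    reset_current_user(token)
```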
@@ -33,10 +33,19 @@ dependencies = [
    "langchain-google-genai>=4.2.1",
    "langgraph-checkpoint-sqlite>=3.0.3",
    "langgraph-sdk>=0.1.51",
    "sqlalchemy[asyncio]>=2.0,<3.0",
    "aiosqlite>=0.19",
    "alembic>=1.13",
]

[project.optional-dependencies]
ollama = ["langchain-ollama>=0.3.0"]
postgres = [
    "asyncpg>=0.29",
    "langgraph-checkpoint-postgres>=3.0.5",
    "psycopg[binary]>=3.3.3",
    "psycopg-pool>=3.3.0",
]
pymupdf = ["pymupdf4llm>=0.0.17"]

[build-system]

@@ -17,11 +17,24 @@ dependencies = [
    "langgraph-sdk>=0.1.51",
    "markdown-to-mrkdwn>=0.3.1",
    "wecom-aibot-python-sdk>=0.1.6",
    "bcrypt>=4.0.0",
    "pyjwt>=2.9.0",
    "email-validator>=2.0.0",
]

[project.optional-dependencies]
postgres = [
    "deerflow-harness[postgres]",
]

[dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]

[tool.pytest.ini_options]
markers = [
    "no_auto_user: disable the conftest autouse contextvar fixture for this test",
]

[tool.uv.workspace]
members = ["packages/harness"]

@@ -0,0 +1,134 @@
"""Helpers for router-level tests that need a stubbed auth context.

The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
ahead of every router, plus ``@require_permission(owner_check=True)``
decorators that read ``request.state.auth`` and call
``thread_meta_repo.check_access``. Router-level unit tests construct
**bare** FastAPI apps that include only one router — they have neither
the auth middleware nor a real thread_meta_repo, so the decorators raise
401 (TestClient path) or ValueError (direct-call path).

This module provides two surfaces:

1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
   ``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
   request, plus a permissive ``thread_meta_repo`` mock on
   ``app.state``. Use from TestClient-based router tests.

2. :func:`call_unwrapped` — invokes the underlying function, bypassing
   the ``@require_permission`` decorator chain by walking ``__wrapped__``.
   Use from direct-call tests that previously imported the route
   function and called it positionally.

Both helpers are deliberately permissive: they never deny a request.
Tests that want to verify the *auth boundary itself* (e.g.
``test_auth_middleware``, ``test_auth_type_system``) build their own
apps with the real middleware — those should not use this module.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import ParamSpec, TypeVar
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from app.gateway.auth.models import User
from app.gateway.authz import AuthContext, Permissions

# Default permission set granted to the stub user. Mirrors `_ALL_PERMISSIONS`
# in authz.py — kept inline so the tests don't import a private symbol.
_STUB_PERMISSIONS: list[str] = [
    Permissions.THREADS_READ,
    Permissions.THREADS_WRITE,
    Permissions.THREADS_DELETE,
    Permissions.RUNS_CREATE,
    Permissions.RUNS_READ,
    Permissions.RUNS_CANCEL,
]


def _make_stub_user() -> User:
    """A deterministic test user — same shape as production, fresh UUID."""
    return User(
        email="router-test@example.com",
        password_hash="x",
        system_role="user",
        id=uuid4(),
    )


class _StubAuthMiddleware(BaseHTTPMiddleware):
    """Stamp a fake user / AuthContext onto every request.

    Mirrors what production ``AuthMiddleware`` does after the JWT decode
    + DB lookup short-circuit, so ``@require_permission`` finds an
    authenticated context and skips its own re-authentication path.
    """

    def __init__(self, app: ASGIApp, user_factory: Callable[[], User]) -> None:
        super().__init__(app)
        self._user_factory = user_factory

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        user = self._user_factory()
        request.state.user = user
        request.state.auth = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS))
        return await call_next(request)


def make_authed_test_app(
    *,
    user_factory: Callable[[], User] | None = None,
    owner_check_passes: bool = True,
) -> FastAPI:
    """Build a FastAPI test app with stub auth + permissive thread_meta_repo.

    Args:
        user_factory: Override the default test user. Must return a fully
            populated :class:`User`. Useful for cross-user isolation tests
            that need a stable id across requests.
        owner_check_passes: When True (default), ``thread_meta_repo.check_access``
            returns True for every call so ``@require_permission(owner_check=True)``
            never blocks the route under test. Pass False to verify that
            permission failures surface correctly.

    Returns:
        A ``FastAPI`` app with the stub middleware installed and
        ``app.state.thread_meta_repo`` set to a permissive mock. The
        caller is still responsible for ``app.include_router(...)``.
    """
    factory = user_factory or _make_stub_user
    app = FastAPI()
    app.add_middleware(_StubAuthMiddleware, user_factory=factory)

    repo = MagicMock()
    repo.check_access = AsyncMock(return_value=owner_check_passes)
    app.state.thread_meta_repo = repo

    return app


_P = ParamSpec("_P")
_R = TypeVar("_R")


def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs) -> _R:
    """Invoke the underlying function of a ``@require_permission``-decorated route.

    ``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
    the way down to the original handler, bypassing every authz +
    require_auth wrapper. Use from tests that need to call route
    functions directly (without TestClient) and don't want to construct
    a fake ``Request`` just to satisfy the decorator. The ``ParamSpec``
    propagates the wrapped route's signature so call sites still get
    parameter checking despite the unwrapping.
    """
    fn: Callable = decorated
    while hasattr(fn, "__wrapped__"):
        fn = fn.__wrapped__  # type: ignore[attr-defined]
    return fn(*args, **kwargs)
@@ -7,6 +7,7 @@ issues when unit-testing lightweight config/registry code in isolation.
import importlib.util
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
@@ -53,3 +54,44 @@ def provisioner_module():
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
#
# Repository methods read ``owner_id`` from a contextvar by default
# (see ``deerflow.runtime.user_context``). Without this fixture, every
# pre-existing persistence test would raise RuntimeError because the
# contextvar is unset. The fixture sets a default test user on every
# test; tests that explicitly want to verify behaviour *without* a user
# context should mark themselves ``@pytest.mark.no_auto_user``.


@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Inject a default ``test-user-autouse`` into the contextvar.

    Opt-out via ``@pytest.mark.no_auto_user``. Uses lazy import so that
    tests which don't touch the persistence layer never pay the cost
    of importing runtime.user_context.
    """
    if request.node.get_closest_marker("no_auto_user"):
        yield
        return

    try:
        from deerflow.runtime.user_context import (
            reset_current_user,
            set_current_user,
        )
    except ImportError:
        yield
        return

    user = SimpleNamespace(id="test-user-autouse", email="test@local")
    token = set_current_user(user)
    try:
        yield
    finally:
        reset_current_user(token)
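An illustrative opt-out test (not part of this change) showing the marker in use:

```python
import pytest

from deerflow.runtime.user_context import AUTO, resolve_owner_id


@pytest.mark.no_auto_user
def test_resolve_owner_id_raises_without_user_context():
    # With the autouse fixture disabled, no user is in the contextvar,
    # so AUTO resolution must fail loudly instead of leaking rows.
    with pytest.raises(RuntimeError, match="no user context"):
        resolve_owner_id(AUTO, method_name="list_threads")
```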
@@ -3,7 +3,7 @@ import zipfile
from pathlib import Path

import pytest
from fastapi import FastAPI
from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import FileResponse
@@ -36,7 +36,7 @@ def test_get_artifact_reads_utf8_text_file_on_windows_locale(tmp_path, monkeypat
    monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)

    request = _make_request()
    response = asyncio.run(artifacts_router.get_artifact("thread-1", "mnt/user-data/outputs/note.txt", request))
    response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", "mnt/user-data/outputs/note.txt", request))

    assert bytes(response.body).decode("utf-8") == text
    assert response.media_type == "text/plain"
@@ -49,7 +49,7 @@ def test_get_artifact_forces_download_for_active_content(tmp_path, monkeypatch,

    monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)

    response = asyncio.run(artifacts_router.get_artifact("thread-1", f"mnt/user-data/outputs/{filename}", _make_request()))
    response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", f"mnt/user-data/outputs/{filename}", _make_request()))

    assert isinstance(response, FileResponse)
    assert response.headers.get("content-disposition", "").startswith("attachment;")
@@ -63,7 +63,7 @@ def test_get_artifact_forces_download_for_active_content_in_skill_archive(tmp_pa

    monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: skill_path)

    response = asyncio.run(artifacts_router.get_artifact("thread-1", f"mnt/user-data/outputs/sample.skill/{filename}", _make_request()))
    response = asyncio.run(call_unwrapped(artifacts_router.get_artifact, "thread-1", f"mnt/user-data/outputs/sample.skill/{filename}", _make_request()))

    assert response.headers.get("content-disposition", "").startswith("attachment;")
    assert bytes(response.body) == content.encode("utf-8")
@@ -75,7 +75,7 @@ def test_get_artifact_download_false_does_not_force_attachment(tmp_path, monkeyp

    monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: artifact_path)

    app = FastAPI()
    app = make_authed_test_app()
    app.include_router(artifacts_router.router)

    with TestClient(app) as client:
@@ -93,7 +93,7 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path

    monkeypatch.setattr(artifacts_router, "resolve_thread_virtual_path", lambda _thread_id, _path: skill_path)

    app = FastAPI()
    app = make_authed_test_app()
    app.include_router(artifacts_router.router)

    with TestClient(app) as client:

@@ -0,0 +1,654 @@
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""

from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.authz import (
    AuthContext,
    Permissions,
    get_auth_context,
    require_auth,
    require_permission,
)

# ── Password Hashing ────────────────────────────────────────────────────────


def test_hash_password_and_verify():
    """Hashing and verification round-trip."""
    password = "s3cr3tP@ssw0rd!"
    hashed = hash_password(password)
    assert hashed != password
    assert verify_password(password, hashed) is True
    assert verify_password("wrongpassword", hashed) is False


def test_hash_password_different_each_time():
    """bcrypt generates unique salts, so same password has different hashes."""
    password = "testpassword"
    h1 = hash_password(password)
    h2 = hash_password(password)
    assert h1 != h2  # Different salts
    # But both verify correctly
    assert verify_password(password, h1) is True
    assert verify_password(password, h2) is True


def test_verify_password_rejects_empty():
    """Empty password should not verify."""
    hashed = hash_password("nonempty")
    assert verify_password("", hashed) is False


# ── JWT ─────────────────────────────────────────────────────────────────────


def test_create_and_decode_token():
    """JWT creation and decoding round-trip."""
    user_id = str(uuid4())
    # Set a valid JWT secret for this test
    import os

    os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
    token = create_access_token(user_id)
    assert isinstance(token, str)

    payload = decode_token(token)
    assert payload is not None
    assert payload.sub == user_id


def test_decode_token_expired():
    """Expired token returns TokenError.EXPIRED."""
    from app.gateway.auth.errors import TokenError

    user_id = str(uuid4())
    # Create token that expires immediately
    token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
    payload = decode_token(token)
    assert payload == TokenError.EXPIRED


def test_decode_token_invalid():
    """Invalid token returns TokenError."""
    from app.gateway.auth.errors import TokenError

    assert isinstance(decode_token("not.a.valid.token"), TokenError)
    assert isinstance(decode_token(""), TokenError)
    assert isinstance(decode_token("completely-wrong"), TokenError)


def test_create_token_custom_expiry():
    """Custom expiry is respected."""
    user_id = str(uuid4())
    token = create_access_token(user_id, expires_delta=timedelta(hours=1))
    payload = decode_token(token)
    assert payload is not None
    assert payload.sub == user_id


# ── AuthContext ────────────────────────────────────────────────────────────


def test_auth_context_unauthenticated():
    """AuthContext with no user."""
    ctx = AuthContext(user=None, permissions=[])
    assert ctx.is_authenticated is False
    assert ctx.has_permission("threads", "read") is False


def test_auth_context_authenticated_no_perms():
    """AuthContext with user but no permissions."""
    user = User(id=uuid4(), email="test@example.com", password_hash="hash")
    ctx = AuthContext(user=user, permissions=[])
    assert ctx.is_authenticated is True
    assert ctx.has_permission("threads", "read") is False


def test_auth_context_has_permission():
    """AuthContext permission checking."""
    user = User(id=uuid4(), email="test@example.com", password_hash="hash")
    perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
    ctx = AuthContext(user=user, permissions=perms)
    assert ctx.has_permission("threads", "read") is True
    assert ctx.has_permission("threads", "write") is True
    assert ctx.has_permission("threads", "delete") is False
    assert ctx.has_permission("runs", "read") is False


def test_auth_context_require_user_raises():
    """require_user raises 401 when not authenticated."""
    ctx = AuthContext(user=None, permissions=[])
    with pytest.raises(HTTPException) as exc_info:
        ctx.require_user()
    assert exc_info.value.status_code == 401


def test_auth_context_require_user_returns_user():
    """require_user returns user when authenticated."""
    user = User(id=uuid4(), email="test@example.com", password_hash="hash")
    ctx = AuthContext(user=user, permissions=[])
    returned = ctx.require_user()
    assert returned == user


# ── get_auth_context helper ─────────────────────────────────────────────────


def test_get_auth_context_not_set():
    """get_auth_context returns None when auth not set on request."""
    mock_request = MagicMock()
    # Make getattr return None (simulating attribute not set)
    mock_request.state = MagicMock()
    del mock_request.state.auth
    assert get_auth_context(mock_request) is None


def test_get_auth_context_set():
    """get_auth_context returns the AuthContext from request."""
    user = User(id=uuid4(), email="test@example.com", password_hash="hash")
    ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])

    mock_request = MagicMock()
    mock_request.state.auth = ctx

    assert get_auth_context(mock_request) == ctx


# ── require_auth decorator ──────────────────────────────────────────────────


def test_require_auth_sets_auth_context():
    """require_auth sets auth context on request from cookie."""
    from fastapi import Request

    app = FastAPI()

    @app.get("/test")
    @require_auth
    async def endpoint(request: Request):
        ctx = get_auth_context(request)
        return {"authenticated": ctx.is_authenticated}

    with TestClient(app) as client:
        # No cookie → anonymous
        response = client.get("/test")
        assert response.status_code == 200
        assert response.json()["authenticated"] is False


def test_require_auth_requires_request_param():
    """require_auth raises ValueError if request parameter is missing."""
    import asyncio

    @require_auth
    async def bad_endpoint():  # Missing `request` parameter
        pass

    with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
        asyncio.run(bad_endpoint())


# ── require_permission decorator ─────────────────────────────────────────────


def test_require_permission_requires_auth():
    """require_permission raises 401 when not authenticated."""
    from fastapi import Request

    app = FastAPI()

    @app.get("/test")
    @require_permission("threads", "read")
    async def endpoint(request: Request):
        return {"ok": True}

    with TestClient(app) as client:
        response = client.get("/test")
        assert response.status_code == 401
        assert "Authentication required" in response.json()["detail"]


def test_require_permission_denies_wrong_permission():
    """User without required permission gets 403."""
    from fastapi import Request

    app = FastAPI()
    user = User(id=uuid4(), email="test@example.com", password_hash="hash")

    @app.get("/test")
    @require_permission("threads", "delete")
    async def endpoint(request: Request):
        return {"ok": True}

    mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])

    with patch("app.gateway.authz._authenticate", return_value=mock_auth):
        with TestClient(app) as client:
            response = client.get("/test")
            assert response.status_code == 403
            assert "Permission denied" in response.json()["detail"]


# ── User Model Fields ──────────────────────────────────────────────────────


def test_user_model_has_needs_setup_default_false():
    """New users default to needs_setup=False."""
    user = User(email="test@example.com", password_hash="hash")
    assert user.needs_setup is False


def test_user_model_has_token_version_default_zero():
    """New users default to token_version=0."""
    user = User(email="test@example.com", password_hash="hash")
    assert user.token_version == 0


def test_user_model_needs_setup_true():
    """Auto-created admin has needs_setup=True."""
    user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
    assert user.needs_setup is True


def test_sqlite_round_trip_new_fields():
    """needs_setup and token_version survive create → read round-trip.

    Uses the shared persistence engine (same one threads_meta, runs,
    run_events, and feedback use). The old separate .deer-flow/users.db
    file is gone.
    """
    import asyncio
    import tempfile

    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

    async def _run() -> None:
        from deerflow.persistence.engine import (
            close_engine,
            get_session_factory,
            init_engine,
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db"
            await init_engine("sqlite", url=url, sqlite_dir=tmpdir)
            try:
                repo = SQLiteUserRepository(get_session_factory())
                user = User(
                    email="setup@test.com",
                    password_hash="fakehash",
                    system_role="admin",
                    needs_setup=True,
                    token_version=3,
                )
                created = await repo.create_user(user)
                assert created.needs_setup is True
                assert created.token_version == 3

                fetched = await repo.get_user_by_email("setup@test.com")
                assert fetched is not None
                assert fetched.needs_setup is True
                assert fetched.token_version == 3

                fetched.needs_setup = False
                fetched.token_version = 4
                await repo.update_user(fetched)
                refetched = await repo.get_user_by_id(str(fetched.id))
                assert refetched is not None
                assert refetched.needs_setup is False
                assert refetched.token_version == 4
            finally:
                await close_engine()

    asyncio.run(_run())
def test_update_user_raises_when_row_concurrently_deleted():
    """Concurrent-delete during update_user must hard-fail, not silently no-op.

    Earlier the SQLite repo returned the input unchanged when the row was
    missing, making a phantom success path that admin password reset
    callers (`reset_admin`, `_ensure_admin_user`) would happily log as
    'password reset'. The new contract: raise ``UserNotFoundError`` so
    a vanished row never looks like a successful update.
    """
    import asyncio
    import tempfile

    from app.gateway.auth.repositories.base import UserNotFoundError
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

    async def _run() -> None:
        from deerflow.persistence.engine import (
            close_engine,
            get_session_factory,
            init_engine,
        )
        from deerflow.persistence.user.model import UserRow

        with tempfile.TemporaryDirectory() as d:
            url = f"sqlite+aiosqlite:///{d}/scratch.db"
            await init_engine("sqlite", url=url, sqlite_dir=d)
            try:
                sf = get_session_factory()
                repo = SQLiteUserRepository(sf)
                user = User(
                    email="ghost@test.com",
                    password_hash="fakehash",
                    system_role="user",
                )
                created = await repo.create_user(user)

                # Simulate "row vanished underneath us" by deleting the row
                # via the raw ORM session, then attempt to update.
                async with sf() as session:
                    row = await session.get(UserRow, str(created.id))
                    assert row is not None
                    await session.delete(row)
                    await session.commit()

                created.needs_setup = True
                with pytest.raises(UserNotFoundError):
                    await repo.update_user(created)
            finally:
                await close_engine()

    asyncio.run(_run())


# ── Token Versioning ───────────────────────────────────────────────────────


def test_jwt_encodes_ver():
    """JWT payload includes ver field."""
    import os

    from app.gateway.auth.errors import TokenError

    os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
    token = create_access_token(str(uuid4()), token_version=3)
    payload = decode_token(token)
    assert not isinstance(payload, TokenError)
    assert payload.ver == 3


def test_jwt_default_ver_zero():
    """JWT ver defaults to 0."""
    import os

    from app.gateway.auth.errors import TokenError

    os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
    token = create_access_token(str(uuid4()))
    payload = decode_token(token)
    assert not isinstance(payload, TokenError)
    assert payload.ver == 0


def test_token_version_mismatch_rejects():
    """Token with stale ver is rejected by get_current_user_from_request."""
    import asyncio
    import os

    os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"

    user_id = str(uuid4())
    token = create_access_token(user_id, token_version=0)

    mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)

    mock_request = MagicMock()
    mock_request.cookies = {"access_token": token}

    with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
        mock_provider = MagicMock()
        mock_provider.get_user = AsyncMock(return_value=mock_user)
        mock_provider_fn.return_value = mock_provider

        from app.gateway.deps import get_current_user_from_request

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(get_current_user_from_request(mock_request))
        assert exc_info.value.status_code == 401
        assert "revoked" in str(exc_info.value.detail).lower()
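The revocation mechanism these tests exercise is a version comparison; a minimal sketch of the check `get_current_user_from_request` is assumed to perform (illustrative, the real deps.py code may differ):

```python
from fastapi import HTTPException


def _reject_stale_token(payload, user) -> None:
    # payload.ver was baked into the JWT at login; bumping
    # user.token_version (e.g. on password change) revokes every
    # previously issued cookie in a single DB write.
    if payload.ver != user.token_version:
        raise HTTPException(status_code=401, detail="Token has been revoked")
```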
# ── change-password extension ──────────────────────────────────────────────


def test_change_password_request_accepts_new_email():
    """ChangePasswordRequest model accepts optional new_email."""
    from app.gateway.routers.auth import ChangePasswordRequest

    req = ChangePasswordRequest(
        current_password="old",
        new_password="newpassword",
        new_email="new@example.com",
    )
    assert req.new_email == "new@example.com"


def test_change_password_request_new_email_optional():
    """ChangePasswordRequest model works without new_email."""
    from app.gateway.routers.auth import ChangePasswordRequest

    req = ChangePasswordRequest(current_password="old", new_password="newpassword")
    assert req.new_email is None


def test_login_response_includes_needs_setup():
    """LoginResponse includes needs_setup field."""
    from app.gateway.routers.auth import LoginResponse

    resp = LoginResponse(expires_in=3600, needs_setup=True)
    assert resp.needs_setup is True
    resp2 = LoginResponse(expires_in=3600)
    assert resp2.needs_setup is False


# ── Rate Limiting ──────────────────────────────────────────────────────────


def test_rate_limiter_allows_under_limit():
    """Requests under the limit are allowed."""
    from app.gateway.routers.auth import _check_rate_limit, _login_attempts

    _login_attempts.clear()
    _check_rate_limit("192.168.1.1")  # Should not raise


def test_rate_limiter_blocks_after_max_failures():
    """IP is blocked after 5 consecutive failures."""
    from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure

    _login_attempts.clear()
    ip = "10.0.0.1"
    for _ in range(5):
        _record_login_failure(ip)
    with pytest.raises(HTTPException) as exc_info:
        _check_rate_limit(ip)
    assert exc_info.value.status_code == 429


def test_rate_limiter_resets_on_success():
    """Successful login clears the failure counter."""
    from app.gateway.routers.auth import (
        _check_rate_limit,
        _login_attempts,
        _record_login_failure,
        _record_login_success,
    )

    _login_attempts.clear()
    ip = "10.0.0.2"
    for _ in range(4):
        _record_login_failure(ip)
    _record_login_success(ip)
    _check_rate_limit(ip)  # Should not raise


# ── Client IP extraction ─────────────────────────────────────────────────


def test_get_client_ip_direct_connection_no_proxy(monkeypatch):
    """Direct mode (no AUTH_TRUSTED_PROXIES): use TCP peer regardless of X-Real-IP."""
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "203.0.113.42"
    req.headers = {}
    assert _get_client_ip(req) == "203.0.113.42"


def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
    """X-Real-IP is silently ignored if AUTH_TRUSTED_PROXIES is unset.

    This closes the bypass where any client could rotate X-Real-IP per
    request to dodge per-IP rate limits in dev / direct mode.
    """
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "127.0.0.1"
    req.headers = {"x-real-ip": "203.0.113.42"}
    assert _get_client_ip(req) == "127.0.0.1"


def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
    """X-Real-IP is honored when the TCP peer matches AUTH_TRUSTED_PROXIES."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "10.5.6.7"  # in trusted CIDR
    req.headers = {"x-real-ip": "203.0.113.42"}
    assert _get_client_ip(req) == "203.0.113.42"


def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
    """X-Real-IP is rejected when the TCP peer is NOT in the trusted list."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "8.8.8.8"  # NOT in trusted CIDR
    req.headers = {"x-real-ip": "203.0.113.42"}  # client trying to spoof
    assert _get_client_ip(req) == "8.8.8.8"


def test_get_client_ip_xff_never_honored(monkeypatch):
    """X-Forwarded-For is never used; only X-Real-IP from a trusted peer."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "10.0.0.1"
    req.headers = {"x-forwarded-for": "198.51.100.5"}  # no x-real-ip
    assert _get_client_ip(req) == "10.0.0.1"


def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
    """Garbage entries in AUTH_TRUSTED_PROXIES are warned and skipped."""
    monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "not-an-ip,10.0.0.0/8")
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "10.5.6.7"
    req.headers = {"x-real-ip": "203.0.113.42"}
    assert _get_client_ip(req) == "203.0.113.42"  # valid entry still works


def test_get_client_ip_no_client_returns_unknown(monkeypatch):
    """No request.client → 'unknown' marker (no crash)."""
    monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client = None
    req.headers = {}
    assert _get_client_ip(req) == "unknown"
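Taken together, these tests pin down a small decision procedure. An illustrative sketch of `_get_client_ip` consistent with them (the real router code may differ in detail):

```python
import ipaddress
import logging
import os

logger = logging.getLogger(__name__)


def _get_client_ip(request) -> str:
    # Illustrative reimplementation of the behaviour the tests above describe.
    if request.client is None:
        return "unknown"
    peer = request.client.host
    trusted = os.environ.get("AUTH_TRUSTED_PROXIES", "")
    for entry in filter(None, (e.strip() for e in trusted.split(","))):
        try:
            network = ipaddress.ip_network(entry, strict=False)
        except ValueError:
            logger.warning("Skipping invalid AUTH_TRUSTED_PROXIES entry: %s", entry)
            continue
        if ipaddress.ip_address(peer) in network:
            # Only a trusted reverse proxy may assert the real client IP,
            # and only via X-Real-IP (X-Forwarded-For is never honoured).
            return request.headers.get("x-real-ip", peer) or peer
    return peer
```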
# ── Common-password blocklist ────────────────────────────────────────────────


def test_register_rejects_literal_password():
    """Pydantic validator rejects 'password' as a registration password."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    with pytest.raises(ValidationError) as exc:
        RegisterRequest(email="x@example.com", password="password")
    assert "too common" in str(exc.value)


def test_register_rejects_common_password_case_insensitive():
    """Case variants of common passwords are also rejected."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    for variant in ["PASSWORD", "Password1", "qwerty123", "letmein1"]:
        with pytest.raises(ValidationError):
            RegisterRequest(email="x@example.com", password=variant)


def test_register_accepts_strong_password():
    """A non-blocklisted password of length >=8 is accepted."""
    from app.gateway.routers.auth import RegisterRequest

    req = RegisterRequest(email="x@example.com", password="Tr0ub4dor&3-Horse")
    assert req.password == "Tr0ub4dor&3-Horse"


def test_change_password_rejects_common_password():
    """The same blocklist applies to change-password."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import ChangePasswordRequest

    with pytest.raises(ValidationError):
        ChangePasswordRequest(current_password="anything", new_password="iloveyou")


def test_password_blocklist_keeps_short_passwords_for_length_check():
    """Short passwords still fail the min_length check (not the blocklist)."""
    from pydantic import ValidationError

    from app.gateway.routers.auth import RegisterRequest

    with pytest.raises(ValidationError) as exc:
        RegisterRequest(email="x@example.com", password="abc")
    # the length check should fire, not the blocklist
    assert "at least 8 characters" in str(exc.value)


# ── Weak JWT secret warning ──────────────────────────────────────────────────


def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
    """get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
    import logging

    import app.gateway.auth.config as config_module

    config_module._auth_config = None
    monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)

    with caplog.at_level(logging.WARNING):
        config = config_module.get_auth_config()

    assert config.jwt_secret  # non-empty ephemeral secret
    assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)

    # Cleanup
    config_module._auth_config = None

@@ -0,0 +1,54 @@
"""Tests for AuthConfig typed configuration."""

import os
from unittest.mock import patch

import pytest

from app.gateway.auth.config import AuthConfig


def test_auth_config_defaults():
    config = AuthConfig(jwt_secret="test-secret-key-123")
    assert config.token_expiry_days == 7


def test_auth_config_token_expiry_range():
    AuthConfig(jwt_secret="s", token_expiry_days=1)
    AuthConfig(jwt_secret="s", token_expiry_days=30)
    with pytest.raises(Exception):
        AuthConfig(jwt_secret="s", token_expiry_days=0)
    with pytest.raises(Exception):
        AuthConfig(jwt_secret="s", token_expiry_days=31)


def test_auth_config_from_env():
    env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
    with patch.dict(os.environ, env, clear=False):
        import app.gateway.auth.config as cfg

        old = cfg._auth_config
        cfg._auth_config = None
        try:
            config = cfg.get_auth_config()
            assert config.jwt_secret == "test-jwt-secret-from-env"
        finally:
            cfg._auth_config = old


def test_auth_config_missing_secret_generates_ephemeral(caplog):
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    cfg._auth_config = None
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            with caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old

@@ -0,0 +1,75 @@
"""Tests for auth error types and typed decode_token."""

from datetime import UTC, datetime, timedelta

import jwt as pyjwt

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import create_access_token, decode_token


def test_auth_error_code_values():
    assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
    assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
    assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"


def test_token_error_values():
    assert TokenError.EXPIRED == "expired"
    assert TokenError.INVALID_SIGNATURE == "invalid_signature"
    assert TokenError.MALFORMED == "malformed"


def test_auth_error_response_serialization():
    err = AuthErrorResponse(
        code=AuthErrorCode.TOKEN_EXPIRED,
        message="Token has expired",
    )
    d = err.model_dump()
    assert d == {"code": "token_expired", "message": "Token has expired"}


def test_auth_error_response_from_dict():
    d = {"code": "invalid_credentials", "message": "Wrong password"}
    err = AuthErrorResponse(**d)
    assert err.code == AuthErrorCode.INVALID_CREDENTIALS


# ── decode_token typed failure tests ──────────────────────────────

_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"


def _setup_config():
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))


def test_decode_token_returns_token_error_on_expired():
    _setup_config()
    expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.EXPIRED


def test_decode_token_returns_token_error_on_bad_signature():
    _setup_config()
    payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.INVALID_SIGNATURE


def test_decode_token_returns_token_error_on_malformed():
    _setup_config()
    result = decode_token("not-a-jwt")
    assert result == TokenError.MALFORMED


def test_decode_token_returns_payload_on_valid():
    _setup_config()
    token = create_access_token("user-123")
    result = decode_token(token)
    assert not isinstance(result, TokenError)
    assert result.sub == "user-123"

@@ -0,0 +1,222 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""

import pytest
from starlette.testclient import TestClient

from app.gateway.auth_middleware import AuthMiddleware, _is_public

# ── _is_public unit tests ─────────────────────────────────────────────────


@pytest.mark.parametrize(
    "path",
    [
        "/health",
        "/health/",
        "/docs",
        "/docs/",
        "/redoc",
        "/openapi.json",
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
    ],
)
def test_public_paths(path: str):
    assert _is_public(path) is True


@pytest.mark.parametrize(
    "path",
    [
        "/api/models",
        "/api/mcp/config",
        "/api/memory",
        "/api/skills",
        "/api/threads/123",
        "/api/threads/123/uploads",
        "/api/agents",
        "/api/channels",
        "/api/runs/stream",
        "/api/threads/123/runs",
        "/api/v1/auth/me",
        "/api/v1/auth/change-password",
    ],
)
def test_protected_paths(path: str):
    assert _is_public(path) is False


# ── Trailing slash / normalization edge cases ─────────────────────────────


@pytest.mark.parametrize(
    "path",
    [
        "/api/v1/auth/login/local/",
        "/api/v1/auth/register/",
        "/api/v1/auth/logout/",
        "/api/v1/auth/setup-status/",
    ],
)
def test_public_auth_paths_with_trailing_slash(path: str):
    assert _is_public(path) is True


@pytest.mark.parametrize(
    "path",
    [
        "/api/models/",
        "/api/v1/auth/me/",
        "/api/v1/auth/change-password/",
    ],
)
def test_protected_paths_with_trailing_slash(path: str):
    assert _is_public(path) is False


def test_unknown_api_path_is_protected():
    """Fail-closed: any new /api/* path is protected by default."""
    assert _is_public("/api/new-feature") is False
    assert _is_public("/api/v2/something") is False
    assert _is_public("/api/v1/auth/new-endpoint") is False
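These parametrized cases fully pin the matcher's behaviour: a closed allowlist with trailing-slash normalisation, everything else protected. An illustrative sketch consistent with them (the real auth_middleware.py may differ in detail):

```python
# Illustrative fail-closed matcher; names here are assumptions.
_PUBLIC_PATHS = {
    "/health", "/docs", "/redoc", "/openapi.json",
    "/api/v1/auth/login/local", "/api/v1/auth/register",
    "/api/v1/auth/logout", "/api/v1/auth/setup-status",
}


def _is_public(path: str) -> bool:
    # Exact match against a closed allowlist, trailing slash normalised;
    # any other path, including unknown future /api/* routes, is protected.
    return path.rstrip("/") in _PUBLIC_PATHS
```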
# ── Middleware integration tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_app():
|
||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/auth/me")
|
||||
async def auth_me():
|
||||
return {"id": "1", "email": "test@test.com"}
|
||||
|
||||
@app.get("/api/v1/auth/setup-status")
|
||||
async def setup_status():
|
||||
return {"needs_setup": False}
|
||||
|
||||
@app.get("/api/models")
|
||||
async def models_get():
|
||||
return {"models": []}
|
||||
|
||||
@app.put("/api/mcp/config")
|
||||
async def mcp_put():
|
||||
return {"ok": True}
|
||||
|
||||
@app.delete("/api/threads/abc")
|
||||
async def thread_delete():
|
||||
return {"ok": True}
|
||||
|
||||
@app.patch("/api/threads/abc")
|
||||
async def thread_patch():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def stream():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/future-endpoint")
|
||||
async def future():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
def test_public_path_no_cookie(client):
|
||||
res = client.get("/health")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_public_auth_path_no_cookie(client):
|
||||
"""Public auth endpoints (login/register) pass without cookie."""
|
||||
res = client.get("/api/v1/auth/setup-status")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_auth_path_no_cookie(client):
|
||||
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
|
||||
res = client.get("/api/v1/auth/me")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_path_no_cookie_returns_401(client):
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
body = res.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_protected_path_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||
tokens through to the route handler."""
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_post_no_cookie_returns_401(client):
|
||||
res = client.post("/api/threads/abc/runs/stream")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
def test_protected_put_no_cookie(client):
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_delete_no_cookie(client):
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_patch_no_cookie(client):
|
||||
res = client.patch("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_put_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie on PUT → 401 (strict JWT validation in middleware)."""
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_delete_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie on DELETE → 401 (strict JWT validation in middleware)."""
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_endpoint_no_cookie_returns_401(client):
|
||||
"""Any new /api/* endpoint is blocked by default without cookie."""
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_unknown_endpoint_with_junk_cookie_rejected(client):
|
||||
"""New endpoints are also protected by strict JWT validation."""
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
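As a reading aid for the middleware tests above: they pin a fail-closed contract in which every /api/* request is rejected with a structured 401 unless it is on a short public allowlist or carries a strictly validated JWT cookie. A minimal sketch of that dispatch follows; it is a hypothetical reconstruction, not the shipped AuthMiddleware, and the allowlist contents, secret handling, and class name are assumptions.

# Hypothetical sketch, not the shipped AuthMiddleware; secret and allowlist are stand-ins.
import jwt as pyjwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

_PUBLIC = {"/health", "/api/v1/auth/login/local", "/api/v1/auth/register", "/api/v1/auth/setup-status"}


class SketchAuthMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        path = request.url.path
        # Non-API and explicitly public paths pass straight through;
        # /api/v1/auth/me is deliberately NOT on the allowlist.
        if not path.startswith("/api/") or path in _PUBLIC:
            return await call_next(request)
        token = request.cookies.get("access_token")
        if not token:
            return JSONResponse(
                status_code=401,
                content={"detail": {"code": "not_authenticated", "message": "Missing access token"}},
            )
        try:
            # Strict validation: junk or tampered tokens never reach the route handler.
            pyjwt.decode(token, "example-secret", algorithms=["HS256"])
        except pyjwt.PyJWTError:
            return JSONResponse(
                status_code=401,
                content={"detail": {"code": "token_invalid", "message": "Token failed validation"}},
            )
        return await call_next(request)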
@@ -0,0 +1,701 @@
"""Tests for auth type system hardening.

Covers structured error responses, typed decode_token callers,
CSRF middleware path matching, config-driven cookie security,
and unhappy paths / edge cases for all auth boundaries.
"""

import os
import secrets
from datetime import UTC, datetime, timedelta
from unittest.mock import patch

import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
    CSRF_COOKIE_NAME,
    CSRF_HEADER_NAME,
    CSRFMiddleware,
    is_auth_endpoint,
    should_check_csrf,
)

# ── Setup ────────────────────────────────────────────────────────────

_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"


@pytest.fixture(autouse=True)
def _persistence_engine(tmp_path):
    """Initialise a per-test SQLite engine + reset cached provider singletons.

    The auth tests call real HTTP handlers that go through
    ``SQLiteUserRepository`` → ``get_session_factory``. Each test gets
    a fresh DB plus a clean ``deps._cached_*`` so the cached provider
    does not hold a dangling reference to the previous test's engine.
    """
    import asyncio

    from app.gateway import deps
    from deerflow.persistence.engine import close_engine, init_engine

    url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
    asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
    deps._cached_local_provider = None
    deps._cached_repo = None
    try:
        yield
    finally:
        deps._cached_local_provider = None
        deps._cached_repo = None
        asyncio.run(close_engine())


def _setup_config():
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))


# ── CSRF Middleware Path Matching ────────────────────────────────────


class _FakeRequest:
    """Minimal request mock for CSRF path matching tests."""

    def __init__(self, path: str, method: str = "POST"):
        self.method = method

        class _URL:
            def __init__(self, p):
                self.path = p

        self.url = _URL(path)
        self.cookies = {}
        self.headers = {}


def test_csrf_exempts_login_local():
    """login/local (actual route) should be exempt from CSRF."""
    req = _FakeRequest("/api/v1/auth/login/local")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_login_local_trailing_slash():
    """Trailing slash should also be exempt."""
    req = _FakeRequest("/api/v1/auth/login/local/")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_logout():
    req = _FakeRequest("/api/v1/auth/logout")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_register():
    req = _FakeRequest("/api/v1/auth/register")
    assert is_auth_endpoint(req) is True


def test_csrf_does_not_exempt_old_login_path():
    """Old /api/v1/auth/login (without /local) should NOT be exempt."""
    req = _FakeRequest("/api/v1/auth/login")
    assert is_auth_endpoint(req) is False


def test_csrf_does_not_exempt_me():
    req = _FakeRequest("/api/v1/auth/me")
    assert is_auth_endpoint(req) is False


def test_csrf_skips_get_requests():
    req = _FakeRequest("/api/v1/auth/me", method="GET")
    assert should_check_csrf(req) is False


def test_csrf_checks_post_to_protected():
    req = _FakeRequest("/api/v1/some/endpoint", method="POST")
    assert should_check_csrf(req) is True


# ── Structured Error Response Format ────────────────────────────────


def test_auth_error_response_has_code_and_message():
    """All auth errors should have structured {code, message} format."""
    err = AuthErrorResponse(
        code=AuthErrorCode.INVALID_CREDENTIALS,
        message="Wrong password",
    )
    d = err.model_dump()
    assert "code" in d
    assert "message" in d
    assert d["code"] == "invalid_credentials"


def test_auth_error_response_all_codes_serializable():
    """Every AuthErrorCode should be serializable in AuthErrorResponse."""
    for code in AuthErrorCode:
        err = AuthErrorResponse(code=code, message=f"Test {code.value}")
        d = err.model_dump()
        assert d["code"] == code.value


# ── decode_token Caller Pattern ──────────────────────────────────────


def test_decode_token_expired_maps_to_token_expired_code():
    """TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
    _setup_config()
    from datetime import UTC, datetime, timedelta

    import jwt as pyjwt

    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.EXPIRED

    # Verify the mapping pattern used in route handlers
    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_EXPIRED


def test_decode_token_invalid_sig_maps_to_token_invalid_code():
    """TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
    _setup_config()
    from datetime import UTC, datetime, timedelta

    import jwt as pyjwt

    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.INVALID_SIGNATURE

    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_INVALID


def test_decode_token_malformed_maps_to_token_invalid_code():
    """TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
    _setup_config()
    result = decode_token("garbage")
    assert result == TokenError.MALFORMED

    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_INVALID


# ── Login Response Format ────────────────────────────────────────────


def test_login_response_model_has_no_access_token():
    """LoginResponse should NOT contain access_token field (RFC-001)."""
    from app.gateway.routers.auth import LoginResponse

    resp = LoginResponse(expires_in=604800)
    d = resp.model_dump()
    assert "access_token" not in d
    assert "expires_in" in d
    assert d["expires_in"] == 604800


def test_login_response_model_fields():
    """LoginResponse has expires_in and needs_setup."""
    from app.gateway.routers.auth import LoginResponse

    fields = set(LoginResponse.model_fields.keys())
    assert fields == {"expires_in", "needs_setup"}


# ── AuthConfig in Route ──────────────────────────────────────────────


def test_auth_config_token_expiry_used_in_login_response():
    """LoginResponse.expires_in should come from config.token_expiry_days."""
    from app.gateway.routers.auth import LoginResponse

    expected_seconds = 14 * 24 * 3600
    resp = LoginResponse(expires_in=expected_seconds)
    assert resp.expires_in == expected_seconds


# ── UserResponse Type Preservation ───────────────────────────────────


def test_user_response_system_role_literal():
    """UserResponse.system_role should only accept 'admin' or 'user'."""
    from app.gateway.auth.models import UserResponse

    # Valid roles
    resp = UserResponse(id="1", email="a@b.com", system_role="admin")
    assert resp.system_role == "admin"

    resp = UserResponse(id="1", email="a@b.com", system_role="user")
    assert resp.system_role == "user"


def test_user_response_rejects_invalid_role():
    """UserResponse should reject invalid system_role values."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com", system_role="superadmin")


# ══════════════════════════════════════════════════════════════════════
# UNHAPPY PATHS / EDGE CASES
# ══════════════════════════════════════════════════════════════════════


# ── get_current_user structured 401 responses ────────────────────────


def test_get_current_user_no_cookie_returns_not_authenticated():
    """No cookie → 401 with code=not_authenticated."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    mock_request = type("MockRequest", (), {"cookies": {}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "not_authenticated"


def test_get_current_user_expired_token_returns_token_expired():
    """Expired token → 401 with code=token_expired."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")

    mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_expired"


def test_get_current_user_invalid_token_returns_token_invalid():
    """Bad signature → 401 with code=token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")

    mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_invalid"


def test_get_current_user_malformed_token_returns_token_invalid():
    """Garbage token → 401 with code=token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_invalid"


# ── decode_token edge cases ──────────────────────────────────────────


def test_decode_token_empty_string_returns_malformed():
    _setup_config()
    result = decode_token("")
    assert result == TokenError.MALFORMED


def test_decode_token_whitespace_returns_malformed():
    _setup_config()
    result = decode_token(" ")
    assert result == TokenError.MALFORMED


# ── AuthConfig validation edge cases ─────────────────────────────────


def test_auth_config_missing_jwt_secret_raises():
    """AuthConfig requires jwt_secret — no default allowed."""
    with pytest.raises(ValidationError):
        AuthConfig()


def test_auth_config_token_expiry_zero_raises():
    """token_expiry_days must be >= 1."""
    with pytest.raises(ValidationError):
        AuthConfig(jwt_secret="secret", token_expiry_days=0)


def test_auth_config_token_expiry_31_raises():
    """token_expiry_days must be <= 30."""
    with pytest.raises(ValidationError):
        AuthConfig(jwt_secret="secret", token_expiry_days=31)


def test_auth_config_token_expiry_boundary_1_ok():
    config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
    assert config.token_expiry_days == 1


def test_auth_config_token_expiry_boundary_30_ok():
    config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
    assert config.token_expiry_days == 30


def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
    """get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    cfg._auth_config = None
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            with caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old


# ── CSRF middleware integration (unhappy paths) ──────────────────────


def _make_csrf_app():
    """Create a minimal FastAPI app with CSRFMiddleware for testing."""
    from fastapi import HTTPException as _HTTPException
    from fastapi.responses import JSONResponse as _JSONResponse

    app = FastAPI()

    @app.exception_handler(_HTTPException)
    async def _http_exc_handler(request, exc):
        return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

    app.add_middleware(CSRFMiddleware)

    @app.post("/api/v1/test/protected")
    async def protected():
        return {"ok": True}

    @app.post("/api/v1/auth/login/local")
    async def login():
        return {"ok": True}

    @app.get("/api/v1/test/read")
    async def read_endpoint():
        return {"ok": True}

    return app


def test_csrf_middleware_blocks_post_without_token():
    """POST to protected endpoint without CSRF token → 403 with structured detail."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/test/protected")
    assert resp.status_code == 403
    assert "CSRF" in resp.json()["detail"]
    assert "missing" in resp.json()["detail"].lower()


def test_csrf_middleware_blocks_post_with_mismatched_token():
    """POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
    client = TestClient(_make_csrf_app())
    client.cookies.set(CSRF_COOKIE_NAME, "token-a")
    resp = client.post(
        "/api/v1/test/protected",
        headers={CSRF_HEADER_NAME: "token-b"},
    )
    assert resp.status_code == 403
    assert "mismatch" in resp.json()["detail"].lower()


def test_csrf_middleware_allows_post_with_matching_token():
    """POST with matching CSRF cookie/header → 200."""
    client = TestClient(_make_csrf_app())
    token = secrets.token_urlsafe(64)
    client.cookies.set(CSRF_COOKIE_NAME, token)
    resp = client.post(
        "/api/v1/test/protected",
        headers={CSRF_HEADER_NAME: token},
    )
    assert resp.status_code == 200


def test_csrf_middleware_allows_get_without_token():
    """GET requests bypass CSRF check."""
    client = TestClient(_make_csrf_app())
    resp = client.get("/api/v1/test/read")
    assert resp.status_code == 200


def test_csrf_middleware_exempts_login_local():
    """POST to login/local is exempt from CSRF (no token yet)."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/auth/login/local")
    assert resp.status_code == 200


def test_csrf_middleware_sets_cookie_on_auth_endpoint():
    """Auth endpoints should receive a CSRF cookie in response."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/auth/login/local")
    assert CSRF_COOKIE_NAME in resp.cookies


# ── UserResponse edge cases ──────────────────────────────────────────


def test_user_response_missing_required_fields():
    """UserResponse with missing fields → ValidationError."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1")  # missing email, system_role

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com")  # missing system_role


def test_user_response_empty_string_role_rejected():
    """Empty string is not a valid role."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com", system_role="")


# ══════════════════════════════════════════════════════════════════════
# HTTP-LEVEL API CONTRACT TESTS
# ══════════════════════════════════════════════════════════════════════


def _make_auth_app():
    """Create FastAPI app with auth routes for contract testing."""
    from app.gateway.app import create_app

    return create_app()


def _get_auth_client():
    """Get TestClient for auth API contract tests."""
    return TestClient(_make_auth_app())


def test_api_auth_me_no_cookie_returns_structured_401():
    """/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
    _setup_config()
    client = _get_auth_client()
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "not_authenticated"
    assert "message" in body["detail"]


def test_api_auth_me_expired_token_returns_structured_401():
    """/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
    _setup_config()
    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")

    client = _get_auth_client()
    client.cookies.set("access_token", token)
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "token_expired"


def test_api_auth_me_invalid_sig_returns_structured_401():
    """/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
    _setup_config()
    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")

    client = _get_auth_client()
    client.cookies.set("access_token", token)
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "token_invalid"


def test_api_login_bad_credentials_returns_structured_401():
    """Login with wrong password → 401 with {code: 'invalid_credentials'}."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "nonexistent@test.com", "password": "wrongpassword"},
    )
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "invalid_credentials"


def test_api_login_success_no_token_in_body():
    """Successful login → response body has expires_in but NOT access_token."""
    _setup_config()
    client = _get_auth_client()
    # Register first
    client.post(
        "/api/v1/auth/register",
        json={"email": "contract-test@test.com", "password": "securepassword123"},
    )
    # Login
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "contract-test@test.com", "password": "securepassword123"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert "expires_in" in body
    assert "access_token" not in body
    # Token should be in cookie, not body
    assert "access_token" in resp.cookies


def test_api_register_duplicate_returns_structured_400():
    """Register with duplicate email → 400 with {code: 'email_already_exists'}."""
    _setup_config()
    client = _get_auth_client()
    email = "dup-contract-test@test.com"
    # First register
    client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
    # Duplicate
    resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"})
    assert resp.status_code == 400
    body = resp.json()
    assert body["detail"]["code"] == "email_already_exists"


# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────


def _unique_email(prefix: str) -> str:
    return f"{prefix}-{secrets.token_hex(4)}@test.com"


def _get_set_cookie_headers(resp) -> list[str]:
    """Extract all set-cookie header values from a TestClient response."""
    return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]


def test_register_http_cookie_httponly_true_secure_false():
    """HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
    )
    assert resp.status_code == 201
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" not in cookie_header.lower().replace("samesite", "")


def test_register_https_cookie_httponly_true_secure_true():
    """HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 201
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()
    assert "max-age" in cookie_header.lower()


def test_login_https_sets_secure_cookie():
    """HTTPS login → access_token cookie has secure flag."""
    _setup_config()
    client = _get_auth_client()
    email = _unique_email("https-login")
    client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": email, "password": "Tr0ub4dor3a"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 200
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()


def test_csrf_cookie_secure_on_https():
    """HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 201
    csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
    assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
    csrf_header = csrf_cookies[0]
    assert "secure" in csrf_header.lower()
    assert "httponly" not in csrf_header.lower()


def test_csrf_cookie_not_secure_on_http():
    """HTTP register → csrf_token cookie does NOT have secure flag."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"},
    )
    assert resp.status_code == 201
    csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
    assert csrf_cookies, "csrf_token cookie not set on HTTP register"
    csrf_header = csrf_cookies[0]
    assert "secure" not in csrf_header.lower().replace("samesite", "")
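The decode_token mapping and the structured 401s asserted throughout this file suggest a dependency shaped roughly like the sketch below. It is illustrative only: the real code is app.gateway.deps.get_current_user_from_request, and the NOT_AUTHENTICATED member name and the success-path payload shape ({"sub": ...}) are assumptions inferred from the serialized codes in the tests, not read from the source.

# Illustrative sketch of the dependency contract the tests pin; see the note above for assumptions.
from fastapi import HTTPException, Request

from app.gateway.auth.errors import AuthErrorCode, TokenError
from app.gateway.auth.jwt import decode_token


def _auth_401(code: AuthErrorCode, message: str) -> HTTPException:
    # Every auth failure serializes as {"detail": {"code": ..., "message": ...}}.
    return HTTPException(status_code=401, detail={"code": code.value, "message": message})


async def sketch_current_user_id(request: Request) -> str:
    token = request.cookies.get("access_token")
    if not token:
        raise _auth_401(AuthErrorCode.NOT_AUTHENTICATED, "Missing access token")  # member name assumed
    result = decode_token(token)
    if isinstance(result, TokenError):
        # Expired keeps its own code; all other token failures collapse to token_invalid.
        code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
        raise _auth_401(code, "Token rejected")
    return result["sub"]  # payload shape assumed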
@@ -14,9 +14,10 @@ class TestCheckpointerNoneFix:
        """make_checkpointer should return InMemorySaver when config.checkpointer is None."""
        from deerflow.agents.checkpointer.async_provider import make_checkpointer

        # Mock get_app_config to return a config with checkpointer=None
        # Mock get_app_config to return a config with checkpointer=None and database=None
        mock_config = MagicMock()
        mock_config.checkpointer = None
        mock_config.database = None

        with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
            async with make_checkpointer() as checkpointer:
@@ -0,0 +1,188 @@
"""Tests for LangChain-to-OpenAI message format converters."""

from __future__ import annotations

import json
from unittest.mock import MagicMock

from deerflow.runtime.converters import (
    langchain_messages_to_openai,
    langchain_to_openai_completion,
    langchain_to_openai_message,
)


def _make_ai_message(content="", tool_calls=None, id="msg-123", usage_metadata=None, response_metadata=None):
    msg = MagicMock()
    msg.type = "ai"
    msg.content = content
    msg.tool_calls = tool_calls or []
    msg.id = id
    msg.usage_metadata = usage_metadata
    msg.response_metadata = response_metadata or {}
    return msg


def _make_human_message(content="Hello"):
    msg = MagicMock()
    msg.type = "human"
    msg.content = content
    return msg


def _make_system_message(content="You are an assistant."):
    msg = MagicMock()
    msg.type = "system"
    msg.content = content
    return msg


def _make_tool_message(content="result", tool_call_id="call-abc"):
    msg = MagicMock()
    msg.type = "tool"
    msg.content = content
    msg.tool_call_id = tool_call_id
    return msg


class TestLangchainToOpenaiMessage:
    def test_ai_message_text_only(self):
        msg = _make_ai_message(content="Hello world")
        result = langchain_to_openai_message(msg)
        assert result["role"] == "assistant"
        assert result["content"] == "Hello world"
        assert "tool_calls" not in result

    def test_ai_message_with_tool_calls(self):
        tool_calls = [
            {"id": "call-1", "name": "bash", "args": {"command": "ls"}},
        ]
        msg = _make_ai_message(content="", tool_calls=tool_calls)
        result = langchain_to_openai_message(msg)
        assert result["role"] == "assistant"
        assert result["content"] is None
        assert len(result["tool_calls"]) == 1
        tc = result["tool_calls"][0]
        assert tc["id"] == "call-1"
        assert tc["type"] == "function"
        assert tc["function"]["name"] == "bash"
        # arguments must be a JSON string
        args = json.loads(tc["function"]["arguments"])
        assert args == {"command": "ls"}

    def test_ai_message_text_and_tool_calls(self):
        tool_calls = [
            {"id": "call-2", "name": "read_file", "args": {"path": "/tmp/x"}},
        ]
        msg = _make_ai_message(content="Reading the file", tool_calls=tool_calls)
        result = langchain_to_openai_message(msg)
        assert result["role"] == "assistant"
        assert result["content"] == "Reading the file"
        assert len(result["tool_calls"]) == 1

    def test_ai_message_empty_content_no_tools(self):
        msg = _make_ai_message(content="")
        result = langchain_to_openai_message(msg)
        assert result["role"] == "assistant"
        assert result["content"] == ""
        assert "tool_calls" not in result

    def test_ai_message_list_content(self):
        # Multimodal content is preserved as-is
        list_content = [
            {"type": "text", "text": "Here is an image"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ]
        msg = _make_ai_message(content=list_content)
        result = langchain_to_openai_message(msg)
        assert result["role"] == "assistant"
        assert result["content"] == list_content

    def test_human_message(self):
        msg = _make_human_message("Tell me a joke")
        result = langchain_to_openai_message(msg)
        assert result["role"] == "user"
        assert result["content"] == "Tell me a joke"

    def test_tool_message(self):
        msg = _make_tool_message(content="file contents here", tool_call_id="call-xyz")
        result = langchain_to_openai_message(msg)
        assert result["role"] == "tool"
        assert result["tool_call_id"] == "call-xyz"
        assert result["content"] == "file contents here"

    def test_system_message(self):
        msg = _make_system_message("You are a helpful assistant.")
        result = langchain_to_openai_message(msg)
        assert result["role"] == "system"
        assert result["content"] == "You are a helpful assistant."


class TestLangchainToOpenaiCompletion:
    def test_basic_completion(self):
        usage_metadata = {"input_tokens": 10, "output_tokens": 20}
        msg = _make_ai_message(
            content="Hello",
            id="msg-abc",
            usage_metadata=usage_metadata,
            response_metadata={"model_name": "gpt-4o", "finish_reason": "stop"},
        )
        result = langchain_to_openai_completion(msg)
        assert result["id"] == "msg-abc"
        assert result["model"] == "gpt-4o"
        assert len(result["choices"]) == 1
        choice = result["choices"][0]
        assert choice["index"] == 0
        assert choice["finish_reason"] == "stop"
        assert choice["message"]["role"] == "assistant"
        assert choice["message"]["content"] == "Hello"
        assert result["usage"] is not None
        assert result["usage"]["prompt_tokens"] == 10
        assert result["usage"]["completion_tokens"] == 20
        assert result["usage"]["total_tokens"] == 30

    def test_completion_with_tool_calls(self):
        tool_calls = [{"id": "call-1", "name": "bash", "args": {}}]
        msg = _make_ai_message(
            content="",
            tool_calls=tool_calls,
            id="msg-tc",
            response_metadata={"model_name": "gpt-4o"},
        )
        result = langchain_to_openai_completion(msg)
        assert result["choices"][0]["finish_reason"] == "tool_calls"

    def test_completion_no_usage(self):
        msg = _make_ai_message(content="Hi", id="msg-nousage", usage_metadata=None)
        result = langchain_to_openai_completion(msg)
        assert result["usage"] is None

    def test_finish_reason_from_response_metadata(self):
        msg = _make_ai_message(
            content="Done",
            id="msg-fr",
            response_metadata={"model_name": "claude-3", "finish_reason": "end_turn"},
        )
        result = langchain_to_openai_completion(msg)
        assert result["choices"][0]["finish_reason"] == "end_turn"

    def test_finish_reason_default_stop(self):
        msg = _make_ai_message(content="Done", id="msg-defstop", response_metadata={})
        result = langchain_to_openai_completion(msg)
        assert result["choices"][0]["finish_reason"] == "stop"


class TestMessagesToOpenai:
    def test_convert_message_list(self):
        human = _make_human_message("Hi")
        ai = _make_ai_message(content="Hello!")
        tool_msg = _make_tool_message("result", "call-1")
        messages = [human, ai, tool_msg]
        result = langchain_messages_to_openai(messages)
        assert len(result) == 3
        assert result[0]["role"] == "user"
        assert result[1]["role"] == "assistant"
        assert result[2]["role"] == "tool"

    def test_empty_list(self):
        assert langchain_messages_to_openai([]) == []
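Read together, these cases fully pin the message-level conversion rules: assistant text maps through unchanged, tool calls become OpenAI "function" entries with JSON-string arguments (and content collapses to None when tool calls are present alongside empty text), and human/tool/system messages map to their OpenAI roles. A behavior-equivalent sketch, reconstructed from the assertions above rather than from the shipped deerflow.runtime.converters source:

# Behavior-equivalent sketch derived from the test assertions; not the shipped converter.
import json


def sketch_langchain_to_openai_message(msg) -> dict:
    if msg.type == "ai":
        out = {"role": "assistant", "content": msg.content}
        if msg.tool_calls:
            # Empty text with tool calls becomes content=None, per the OpenAI schema;
            # arguments are serialized as a JSON string.
            out["content"] = msg.content or None
            out["tool_calls"] = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {"name": tc["name"], "arguments": json.dumps(tc["args"])},
                }
                for tc in msg.tool_calls
            ]
        return out
    if msg.type == "tool":
        return {"role": "tool", "tool_call_id": msg.tool_call_id, "content": msg.content}
    role = {"human": "user", "system": "system"}.get(msg.type, msg.type)
    return {"role": role, "content": msg.content}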
@@ -0,0 +1,319 @@
"""Tests for _ensure_admin_user() in app.py.

Covers: first-boot admin creation, auto-reset on needs_setup=True,
no-op on needs_setup=False, migration, and edge cases.
"""

import asyncio
import os
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.models import User

_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"


@pytest.fixture(autouse=True)
def _setup_auth_config():
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
    yield
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))


def _make_app_stub(store=None):
    """Minimal app-like object with state.store."""
    app = SimpleNamespace()
    app.state = SimpleNamespace()
    app.state.store = store
    return app


def _make_provider(user_count=0, admin_user=None):
    p = AsyncMock()
    p.count_users = AsyncMock(return_value=user_count)
    p.create_user = AsyncMock(
        side_effect=lambda **kw: User(
            email=kw["email"],
            password_hash="hashed",
            system_role=kw.get("system_role", "user"),
            needs_setup=kw.get("needs_setup", False),
        )
    )
    p.get_user_by_email = AsyncMock(return_value=admin_user)
    p.update_user = AsyncMock(side_effect=lambda u: u)
    return p


# ── First boot: no users ─────────────────────────────────────────────────


def test_first_boot_creates_admin():
    """count_users==0 → create admin with needs_setup=True."""
    provider = _make_provider(user_count=0)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    provider.create_user.assert_called_once()
    call_kwargs = provider.create_user.call_args[1]
    assert call_kwargs["email"] == "admin@deerflow.dev"
    assert call_kwargs["system_role"] == "admin"
    assert call_kwargs["needs_setup"] is True
    assert len(call_kwargs["password"]) > 10  # random password generated


def test_first_boot_triggers_migration_if_store_present():
    """First boot with store → _migrate_orphaned_threads called."""
    provider = _make_provider(user_count=0)
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    app = _make_app_stub(store=store)

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    store.asearch.assert_called_once()


def test_first_boot_no_store_skips_migration():
    """First boot without store → no crash, migration skipped."""
    provider = _make_provider(user_count=0)
    app = _make_app_stub(store=None)

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    provider.create_user.assert_called_once()


# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────


def test_needs_setup_true_resets_password():
    """Existing admin with needs_setup=True → password reset + token_version bumped."""
    admin = User(
        email="admin@deerflow.dev",
        password_hash="old-hash",
        system_role="admin",
        needs_setup=True,
        token_version=0,
        created_at=datetime.now(UTC) - timedelta(seconds=30),
    )
    provider = _make_provider(user_count=1, admin_user=admin)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    # Password was reset
    provider.update_user.assert_called_once()
    updated = provider.update_user.call_args[0][0]
    assert updated.password_hash == "new-hash"
    assert updated.token_version == 1


def test_needs_setup_true_consecutive_resets_increment_version():
    """Two boots with needs_setup=True → token_version increments each time."""
    admin = User(
        email="admin@deerflow.dev",
        password_hash="hash",
        system_role="admin",
        needs_setup=True,
        token_version=3,
        created_at=datetime.now(UTC) - timedelta(seconds=30),
    )
    provider = _make_provider(user_count=1, admin_user=admin)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    updated = provider.update_user.call_args[0][0]
    assert updated.token_version == 4


# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────


def test_needs_setup_false_no_reset():
    """Admin with needs_setup=False → no password reset, no update."""
    admin = User(
        email="admin@deerflow.dev",
        password_hash="stable-hash",
        system_role="admin",
        needs_setup=False,
        token_version=2,
    )
    provider = _make_provider(user_count=1, admin_user=admin)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(app))

    provider.update_user.assert_not_called()
    assert admin.password_hash == "stable-hash"
    assert admin.token_version == 2


# ── Edge cases ───────────────────────────────────────────────────────────


def test_no_admin_email_found_no_crash():
    """Users exist but no admin@deerflow.dev → no crash, no reset."""
    provider = _make_provider(user_count=3, admin_user=None)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(app))

    provider.update_user.assert_not_called()
    provider.create_user.assert_not_called()


def test_migration_failure_is_non_fatal():
    """_migrate_orphaned_threads exception is caught and logged."""
    provider = _make_provider(user_count=0)
    store = AsyncMock()
    store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
    app = _make_app_stub(store=store)

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            # Should not raise
            asyncio.run(_ensure_admin_user(app))

    provider.create_user.assert_called_once()


# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────


def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
    """First boot finds Store-only legacy threads → stamps admin's id.

    Validates the **TC-UPG-02 upgrade story**: an operator running main
    (no auth) accumulates threads in the LangGraph Store namespace
    ``("threads",)`` with no ``metadata.owner_id``. After upgrading to
    feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
    rewrite each unowned item with the freshly created admin's id.
    """
    from app.gateway.app import _migrate_orphaned_threads

    # Three orphan items + one already-owned item that should be left alone.
    items = [
        SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
        SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
        SimpleNamespace(key="t3", value={"metadata": {}}),
        SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
    ]
    store = AsyncMock()
    # asearch returns the entire batch on first call, then an empty page
    # to terminate _iter_store_items.
    store.asearch = AsyncMock(side_effect=[items, []])
    aput_calls: list[tuple[tuple, str, dict]] = []

    async def _record_aput(namespace, key, value):
        aput_calls.append((namespace, key, value))

    store.aput = AsyncMock(side_effect=_record_aput)

    migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42"))

    # Three orphan rows migrated, one preserved.
    assert migrated == 3
    assert len(aput_calls) == 3
    rewritten_keys = {call[1] for call in aput_calls}
    assert rewritten_keys == {"t1", "t2", "t3"}
    # Each rewrite carries the new owner_id; titles preserved where present.
    by_key = {call[1]: call[2] for call in aput_calls}
    assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
    assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
    assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
    # The pre-owned item must NOT have been rewritten.
    assert "t4" not in rewritten_keys


def test_migrate_orphaned_threads_empty_store_is_noop():
    """A store with no threads → migrated == 0, no aput calls."""
    from app.gateway.app import _migrate_orphaned_threads

    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    store.aput = AsyncMock()

    migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42"))

    assert migrated == 0
    store.aput.assert_not_called()


def test_iter_store_items_walks_multiple_pages():
    """Cursor-style iterator pulls every page until a short page terminates.

    Closes the regression where the old hardcoded ``limit=1000`` could
    silently drop orphans on a large pre-upgrade dataset. The migration
    code path uses the default ``page_size=500``; this test pins the
    iterator with ``page_size=2`` so it stays fast.
    """
    from app.gateway.app import _iter_store_items

    page_a = [SimpleNamespace(key=f"t{i}", value={"metadata": {}}) for i in range(2)]
    page_b = [SimpleNamespace(key=f"t{i + 2}", value={"metadata": {}}) for i in range(2)]
    page_c: list = []  # short page → loop terminates

    store = AsyncMock()
    store.asearch = AsyncMock(side_effect=[page_a, page_b, page_c])

    async def _collect():
        return [item.key async for item in _iter_store_items(store, ("threads",), page_size=2)]

    keys = asyncio.run(_collect())
    assert keys == ["t0", "t1", "t2", "t3"]
    # Three asearch calls: full batch, full batch, empty terminator
    assert store.asearch.await_count == 3


def test_iter_store_items_terminates_on_short_page():
    """A short page (len < page_size) ends the loop without an extra call."""
    from app.gateway.app import _iter_store_items

    page = [SimpleNamespace(key=f"t{i}", value={}) for i in range(3)]
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=page)

    async def _collect():
        return [item.key async for item in _iter_store_items(store, ("threads",), page_size=10)]

    keys = asyncio.run(_collect())
    assert keys == ["t0", "t1", "t2"]
    # Only one call — no terminator probe needed because len(batch) < page_size
    assert store.asearch.await_count == 1
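The two iterator tests above imply a cursor loop along these lines: fetch a page, yield its items, and stop as soon as a page comes back shorter than page_size, so an exact-multiple dataset costs one extra empty probe while a short final page costs none. A sketch under those assumptions (the limit/offset keywords on asearch are inferred from the mocks, not read from the real _iter_store_items):

# Sketch of the paging contract the tests pin down; signature details are assumptions.
from collections.abc import AsyncIterator


async def sketch_iter_store_items(store, namespace: tuple, page_size: int = 500) -> AsyncIterator:
    offset = 0
    while True:
        batch = await store.asearch(namespace, limit=page_size, offset=offset)
        for item in batch:
            yield item
        # A short (or empty) page means the store is exhausted; no extra probe call is made.
        if len(batch) < page_size:
            return
        offset += page_size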
@@ -0,0 +1,215 @@
"""Tests for FeedbackRepository and follow-up association.

Uses temp SQLite DB for ORM tests.
"""

import pytest

from deerflow.persistence.feedback import FeedbackRepository


async def _make_feedback_repo(tmp_path):
    from deerflow.persistence.engine import get_session_factory, init_engine

    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    return FeedbackRepository(get_session_factory())


async def _cleanup():
    from deerflow.persistence.engine import close_engine

    await close_engine()


# -- FeedbackRepository --


class TestFeedbackRepository:
    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1)
        assert record["feedback_id"]
        assert record["rating"] == 1
        assert record["run_id"] == "r1"
        assert record["thread_id"] == "t1"
        assert "created_at" in record
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(
            run_id="r1",
            thread_id="t1",
            rating=-1,
            comment="Response was inaccurate",
        )
        assert record["rating"] == -1
        assert record["comment"] == "Response was inaccurate"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
        assert record["message_id"] == "msg-42"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
        assert record["owner_id"] == "user-1"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=0)
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=5)
        await _cleanup()

    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        fetched = await repo.get(created["feedback_id"])
        assert fetched is not None
        assert fetched["feedback_id"] == created["feedback_id"]
        assert fetched["rating"] == 1
        await _cleanup()

    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        assert await repo.get("nonexistent") is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=-1)
        await repo.create(run_id="r2", thread_id="t1", rating=1)
        results = await repo.list_by_run("t1", "r1")
        assert len(results) == 2
        assert all(r["run_id"] == "r1" for r in results)
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r2", thread_id="t1", rating=-1)
        await repo.create(run_id="r3", thread_id="t2", rating=1)
        results = await repo.list_by_thread("t1")
        assert len(results) == 2
        assert all(r["thread_id"] == "t1" for r in results)
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        deleted = await repo.delete(created["feedback_id"])
        assert deleted is True
        assert await repo.get(created["feedback_id"]) is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        deleted = await repo.delete("nonexistent")
        assert deleted is False
        await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=-1)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 3
        assert stats["positive"] == 2
        assert stats["negative"] == 1
        assert stats["run_id"] == "r1"
        await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 0
        assert stats["positive"] == 0
        assert stats["negative"] == 0
        await _cleanup()


# -- Follow-up association --


class TestFollowUpAssociation:
    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """MemoryRunStore stores follow_up_to_run_id in kwargs."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        # MemoryRunStore doesn't have follow_up_to_run_id as a top-level param,
        # but it can be passed via metadata
        await store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"})
        run = await store.get("r2")
        assert run["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """human_message event metadata includes follow_up_to_run_id."""
        from deerflow.runtime.events.store.memory import MemoryRunEventStore

        event_store = MemoryRunEventStore()
        await event_store.put(
            thread_id="t1",
            run_id="r2",
            event_type="human_message",
            category="message",
            content="Tell me more about that",
            metadata={"follow_up_to_run_id": "r1"},
        )
        messages = await event_store.list_messages("t1")
        assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate the auto-detection: latest successful run becomes follow_up_to."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        await store.put("r2", thread_id="t1", status="error")

        # Auto-detect: list_by_thread returns newest first
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        # r2 (error) is newest, so no follow_up detected
        assert follow_up is None

        # Now add a successful run
        await store.put("r3", thread_id="t1", status="success")
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        assert follow_up == "r3"
@@ -0,0 +1,312 @@
|
||||
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
|
||||
|
||||
Validates that the LangGraph auth layer enforces the same rules as Gateway:
|
||||
cookie → JWT decode → DB lookup → token_version check → owner filter
|
||||
"""
|
||||
import asyncio
import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import pytest

os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")

from langgraph_sdk import Auth

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate

# ── Helpers ───────────────────────────────────────────────────────────────

_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"


@pytest.fixture(autouse=True)
def _setup_auth_config():
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
    yield
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))


def _req(cookies=None, method="GET", headers=None):
    return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})


def _user(user_id=None, token_version=0):
    return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)


def _mock_provider(user=None):
    p = AsyncMock()
    p.get_user = AsyncMock(return_value=user)
    return p


# ── @auth.authenticate ───────────────────────────────────────────────────


def test_no_cookie_raises_401():
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req()))
    assert exc.value.status_code == 401
    assert "Not authenticated" in str(exc.value.detail)


def test_invalid_jwt_raises_401():
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": "garbage"})))
    assert exc.value.status_code == 401
    assert "Token error" in str(exc.value.detail)


def test_expired_jwt_raises_401():
    token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401


def test_user_not_found_raises_401():
    token = create_access_token("ghost")
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401
    assert "User not found" in str(exc.value.detail)


def test_token_version_mismatch_raises_401():
    user = _user(token_version=2)
    token = create_access_token(str(user.id), token_version=1)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401
    assert "revoked" in str(exc.value.detail).lower()


def test_valid_token_returns_user_id():
    user = _user(token_version=0)
    token = create_access_token(str(user.id), token_version=0)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": token})))
    assert result == str(user.id)


def test_valid_token_matching_version():
    user = _user(token_version=5)
    token = create_access_token(str(user.id), token_version=5)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": token})))
    assert result == str(user.id)


# ── @auth.authenticate edge cases ────────────────────────────────────────


def test_provider_exception_propagates():
    """Provider raises → should not be swallowed silently."""
    token = create_access_token("user-1")
    p = AsyncMock()
    p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
        with pytest.raises(RuntimeError, match="DB down"):
            asyncio.run(authenticate(_req({"access_token": token})))


def test_jwt_missing_ver_defaults_to_zero():
    """JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
    import jwt as pyjwt

    uid = str(uuid4())
    raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=0)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": raw})))
    assert result == uid


def test_jwt_missing_ver_rejected_when_user_version_nonzero():
    """JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
    import jwt as pyjwt

    uid = str(uuid4())
    raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=1)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": raw})))
    assert exc.value.status_code == 401


def test_wrong_secret_raises_401():
    """Token signed with different secret → 401."""
    import jwt as pyjwt

    raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": raw})))
    assert exc.value.status_code == 401


# ── @auth.on (owner filter) ──────────────────────────────────────────────


class _FakeUser:
    """Minimal BaseUser-compatible object without langgraph_api.config dependency."""

    def __init__(self, identity: str):
        self.identity = identity
        self.is_authenticated = True
        self.display_name = identity


def _make_ctx(user_id):
    return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])


def test_filter_injects_user_id():
    value = {}
    asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
    assert value["metadata"]["owner_id"] == "user-a"


def test_filter_preserves_existing_metadata():
    value = {"metadata": {"title": "hello"}}
    asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
    assert value["metadata"]["owner_id"] == "user-a"
    assert value["metadata"]["title"] == "hello"


def test_filter_returns_user_id_dict():
    result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
    assert result == {"owner_id": "user-x"}


def test_filter_read_write_consistency():
    value = {}
    filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
    assert value["metadata"]["owner_id"] == filter_dict["owner_id"]


def test_different_users_different_filters():
    f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
    f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
    assert f_a["owner_id"] != f_b["owner_id"]


def test_filter_overrides_conflicting_user_id():
    """If value already has a different user_id in metadata, it gets overwritten."""
    value = {"metadata": {"owner_id": "attacker"}}
    asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
    assert value["metadata"]["owner_id"] == "real-owner"


def test_filter_with_empty_metadata():
    """Explicit empty metadata dict is fine."""
    value = {"metadata": {}}
    result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
    assert value["metadata"]["owner_id"] == "user-z"
    assert result == {"owner_id": "user-z"}


# ── Gateway parity ───────────────────────────────────────────────────────


def test_shared_jwt_secret():
    token = create_access_token("user-1", token_version=3)
    payload = decode_token(token)
    from app.gateway.auth.errors import TokenError

    assert not isinstance(payload, TokenError)
    assert payload.sub == "user-1"
    assert payload.ver == 3


def test_langgraph_json_has_auth_path():
    import json

    config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
    assert "auth" in config
    assert "langgraph_auth" in config["auth"]["path"]


def test_auth_handler_has_both_layers():
    from app.gateway.langgraph_auth import auth

    assert auth._authenticate_handler is not None
    assert len(auth._global_handlers) == 1


# ── CSRF in LangGraph auth ──────────────────────────────────────────────


def test_csrf_get_no_check():
    """GET requests skip CSRF — should proceed to JWT validation."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="GET")))
    # Rejected by missing cookie, NOT by CSRF
    assert exc.value.status_code == 401
    assert "Not authenticated" in str(exc.value.detail)


def test_csrf_post_missing_token():
    """POST without CSRF token → 403."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
    assert exc.value.status_code == 403
    assert "CSRF token missing" in str(exc.value.detail)


def test_csrf_post_mismatched_token():
    """POST with mismatched CSRF tokens → 403."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(
            authenticate(
                _req(
                    method="POST",
                    cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
                    headers={"x-csrf-token": "wrong-token"},
                )
            )
        )
    assert exc.value.status_code == 403
    assert "mismatch" in str(exc.value.detail)


def test_csrf_post_matching_token_proceeds_to_jwt():
    """POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(
            authenticate(
                _req(
                    method="POST",
                    cookies={"access_token": "garbage", "csrf_token": "same-token"},
                    headers={"x-csrf-token": "same-token"},
                )
            )
        )
    # Past CSRF, rejected by JWT decode
    assert exc.value.status_code == 401
    assert "Token error" in str(exc.value.detail)


def test_csrf_put_requires_token():
    """PUT also requires CSRF."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
    assert exc.value.status_code == 403


def test_csrf_delete_requires_token():
    """DELETE also requires CSRF."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
    assert exc.value.status_code == 403

@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
        lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
    )

    from unittest.mock import MagicMock

    captured: dict[str, object] = {}
    fake_model = object()
    fake_model = MagicMock()
    fake_model.with_config.return_value = fake_model

    def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
        captured["name"] = name
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
    assert captured["name"] == "model-masswork"
    assert captured["thinking_enabled"] is False
    assert middleware["model"] is fake_model
    fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])

@@ -793,6 +793,84 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
    assert captured.get("reasoning_effort") is None


# ---------------------------------------------------------------------------
# stream_usage injection
# ---------------------------------------------------------------------------


class _FakeWithStreamUsage(FakeChatModel):
    """Fake model that declares stream_usage in model_fields (like BaseChatOpenAI)."""

    stream_usage: bool | None = None

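# Hedged sketch of the capability probe the three tests below pin down. The
# helper name (_sketch_inject_stream_usage) is an illustrative assumption, not
# the factory's actual code. Pydantic v2 classes expose their declared fields
# via `model_fields`, so support can be probed without importing
# provider-specific classes:
def _sketch_inject_stream_usage(model_cls: type, kwargs: dict) -> dict:
    """Set stream_usage=True only when the class declares the field and the
    caller's config did not set it explicitly."""
    fields = getattr(model_cls, "model_fields", {})
    if "stream_usage" in fields and "stream_usage" not in kwargs:
        return {**kwargs, "stream_usage": True}
    return kwargs
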
def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
    """Factory should set stream_usage=True for models with stream_usage field."""
    cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    captured: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    factory_module.create_chat_model(name="deepseek")

    assert captured.get("stream_usage") is True


def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
    """Factory should NOT inject stream_usage for models without the field."""
    cfg = _make_app_config([_make_model("claude", use="langchain_anthropic:ChatAnthropic")])
    _patch_factory(monkeypatch, cfg)

    captured: dict = {}

    class CapturingModel(FakeChatModel):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    factory_module.create_chat_model(name="claude")

    assert "stream_usage" not in captured


def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
    """If config dumps stream_usage=False, factory should respect it."""
    cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    captured: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    # Simulate config having stream_usage explicitly set by patching get_model_config
    original_get_model_config = cfg.get_model_config

    def patched_get_model_config(name):
        mc = original_get_model_config(name)
        mc.stream_usage = False  # type: ignore[attr-defined]
        return mc

    monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)

    factory_module.create_chat_model(name="deepseek")

    assert captured.get("stream_usage") is False


def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
    model = ModelConfig(
        name="gpt-5-responses",

@@ -0,0 +1,465 @@
"""Cross-user isolation tests — non-negotiable safety gate.

Mirrors TC-API-17..20 from backend/docs/AUTH_TEST_PLAN.md. A failure
here means users can see each other's data; PR must not merge.

Architecture note
-----------------
These tests bypass the HTTP layer and exercise the storage-layer
owner filter directly by switching the ``user_context`` contextvar
between two users. The safety property under test is:

    After a repository write with owner_id=A, a subsequent read with
    owner_id=B must not return the row, and vice versa.

The HTTP layer is covered by test_auth_middleware.py, which proves
that a request cookie reaches the ``set_current_user`` call. Together
the two suites prove the full chain:

    cookie → middleware → contextvar → repository → isolation

Every test in this file opts out of the autouse contextvar fixture
(``@pytest.mark.no_auto_user``) so it can set the contextvar to the
specific users it cares about.
"""
from __future__ import annotations

from types import SimpleNamespace

import pytest

from deerflow.runtime.user_context import (
    reset_current_user,
    set_current_user,
)

USER_A = SimpleNamespace(id="user-a", email="a@test.local")
USER_B = SimpleNamespace(id="user-b", email="b@test.local")


async def _make_engines(tmp_path):
    """Initialize the shared engine against a per-test SQLite DB.

    Returns a cleanup coroutine function the caller should call and
    await at the end.
    """
    from deerflow.persistence.engine import close_engine, init_engine

    url = f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    return close_engine


def _as_user(user):
    """Context manager-like helper that sets and resets the contextvar."""

    class _Ctx:
        def __enter__(self):
            self._token = set_current_user(user)
            return user

        def __exit__(self, *exc):
            reset_current_user(self._token)

    return _Ctx()

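# Hedged sketch of the contextvar mechanism _as_user wraps. The names below
# (_sketch_current_user, _sketch_owner_id) are illustrative assumptions, not
# the real deerflow.runtime.user_context API.
from contextvars import ContextVar

_sketch_current_user: ContextVar = ContextVar("sketch_current_user", default=None)


def _sketch_owner_id() -> str:
    """The owner predicate every repository query is ANDed with."""
    user = _sketch_current_user.get()
    if user is None:
        # Defense-in-depth: repositories refuse to run without a user context.
        raise RuntimeError("no user context is set")
    return user.id
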
# ── TC-API-17 — threads_meta isolation ────────────────────────────────────


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_isolation(tmp_path):
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        # User A creates a thread.
        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="A's private thread")

        # User B creates a thread.
        with _as_user(USER_B):
            await repo.create("t-beta", display_name="B's private thread")

        # User A must see only A's thread.
        with _as_user(USER_A):
            a_view = await repo.get("t-alpha")
            assert a_view is not None
            assert a_view["display_name"] == "A's private thread"

            # CRITICAL: User A must NOT see B's thread.
            leaked = await repo.get("t-beta")
            assert leaked is None, f"User A leaked User B's thread: {leaked}"

            # Search should only return A's threads.
            results = await repo.search()
            assert [r["thread_id"] for r in results] == ["t-alpha"]

        # User B must see only B's thread.
        with _as_user(USER_B):
            b_view = await repo.get("t-beta")
            assert b_view is not None
            assert b_view["display_name"] == "B's private thread"

            leaked = await repo.get("t-alpha")
            assert leaked is None, f"User B leaked User A's thread: {leaked}"

            results = await repo.search()
            assert [r["thread_id"] for r in results] == ["t-beta"]
    finally:
        await cleanup()


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_mutation_denied(tmp_path):
    """User B cannot update or delete a thread owned by User A."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="original")

        # User B tries to rename A's thread — must be a no-op.
        with _as_user(USER_B):
            await repo.update_display_name("t-alpha", "hacked")

        # Verify the row is unchanged from A's perspective.
        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
            assert row["display_name"] == "original"

        # User B tries to delete A's thread — must be a no-op.
        with _as_user(USER_B):
            await repo.delete("t-alpha")

        # A's thread still exists.
        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
    finally:
        await cleanup()


# ── TC-API-18 — runs isolation ────────────────────────────────────────────


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_isolation(tmp_path):
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")
            await repo.put("run-a2", thread_id="t-alpha")

        with _as_user(USER_B):
            await repo.put("run-b1", thread_id="t-beta")

        # User A must see only A's runs.
        with _as_user(USER_A):
            r = await repo.get("run-a1")
            assert r is not None
            assert r["run_id"] == "run-a1"

            leaked = await repo.get("run-b1")
            assert leaked is None, "User A leaked User B's run"

            a_runs = await repo.list_by_thread("t-alpha")
            assert {r["run_id"] for r in a_runs} == {"run-a1", "run-a2"}

            # Listing B's thread from A's perspective: empty
            empty = await repo.list_by_thread("t-beta")
            assert empty == []

        # User B must see only B's runs.
        with _as_user(USER_B):
            leaked = await repo.get("run-a1")
            assert leaked is None, "User B leaked User A's run"

            b_runs = await repo.list_by_thread("t-beta")
            assert [r["run_id"] for r in b_runs] == ["run-b1"]
    finally:
        await cleanup()


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_delete_denied(tmp_path):
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")

        # User B tries to delete A's run — no-op.
        with _as_user(USER_B):
            await repo.delete("run-a1")

        # A's run still exists.
        with _as_user(USER_A):
            row = await repo.get("run-a1")
            assert row is not None
    finally:
        await cleanup()


# ── TC-API-19 — run_events isolation (CRITICAL: content leak) ─────────────


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_isolation(tmp_path):
    """run_events holds raw conversation content — most sensitive leak vector."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    cleanup = await _make_engines(tmp_path)
    try:
        store = DbRunEventStore(get_session_factory())

        with _as_user(USER_A):
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="User A private question",
            )
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="ai_message",
                category="message",
                content="User A private answer",
            )

        with _as_user(USER_B):
            await store.put(
                thread_id="t-beta",
                run_id="run-b1",
                event_type="human_message",
                category="message",
                content="User B private question",
            )

        # User A must see only A's events — CRITICAL.
        with _as_user(USER_A):
            msgs = await store.list_messages("t-alpha")
            contents = [m["content"] for m in msgs]
            assert "User A private question" in contents
            assert "User A private answer" in contents
            # CRITICAL: User B's content must not appear.
            assert "User B private question" not in contents

            # Attempt to read B's thread by guessing thread_id.
            leaked = await store.list_messages("t-beta")
            assert leaked == [], f"User A leaked User B's messages: {leaked}"

            leaked_events = await store.list_events("t-beta", "run-b1")
            assert leaked_events == [], "User A leaked User B's events"

            # count_messages must also be zero for B's thread from A's view.
            count = await store.count_messages("t-beta")
            assert count == 0

        # User B must see only B's events.
        with _as_user(USER_B):
            msgs = await store.list_messages("t-beta")
            contents = [m["content"] for m in msgs]
            assert "User B private question" in contents
            assert "User A private question" not in contents
            assert "User A private answer" not in contents

            count = await store.count_messages("t-alpha")
            assert count == 0
    finally:
        await cleanup()


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_delete_denied(tmp_path):
    """User B cannot delete User A's event stream."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    cleanup = await _make_engines(tmp_path)
    try:
        store = DbRunEventStore(get_session_factory())

        with _as_user(USER_A):
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="hello",
            )

        # User B tries to wipe A's thread events.
        with _as_user(USER_B):
            removed = await store.delete_by_thread("t-alpha")
            assert removed == 0, f"User B deleted {removed} of User A's events"

        # A's events still exist.
        with _as_user(USER_A):
            count = await store.count_messages("t-alpha")
            assert count == 1
    finally:
        await cleanup()


# ── TC-API-20 — feedback isolation ────────────────────────────────────────


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_isolation(tmp_path):
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = FeedbackRepository(get_session_factory())

        # User A submits positive feedback.
        with _as_user(USER_A):
            a_feedback = await repo.create(
                run_id="run-a1",
                thread_id="t-alpha",
                rating=1,
                comment="A liked this",
            )

        # User B submits negative feedback.
        with _as_user(USER_B):
            b_feedback = await repo.create(
                run_id="run-b1",
                thread_id="t-beta",
                rating=-1,
                comment="B disliked this",
            )

        # User A must see only A's feedback.
        with _as_user(USER_A):
            retrieved = await repo.get(a_feedback["feedback_id"])
            assert retrieved is not None
            assert retrieved["comment"] == "A liked this"

            # CRITICAL: cannot read B's feedback by id.
            leaked = await repo.get(b_feedback["feedback_id"])
            assert leaked is None, "User A leaked User B's feedback"

            # list_by_run for B's run must be empty.
            empty = await repo.list_by_run("t-beta", "run-b1")
            assert empty == []

        # User B must see only B's feedback.
        with _as_user(USER_B):
            leaked = await repo.get(a_feedback["feedback_id"])
            assert leaked is None, "User B leaked User A's feedback"

            b_list = await repo.list_by_run("t-beta", "run-b1")
            assert len(b_list) == 1
            assert b_list[0]["comment"] == "B disliked this"
    finally:
        await cleanup()


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_delete_denied(tmp_path):
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = FeedbackRepository(get_session_factory())

        with _as_user(USER_A):
            fb = await repo.create(run_id="run-a1", thread_id="t-alpha", rating=1)

        # User B tries to delete A's feedback — must return False (no-op).
        with _as_user(USER_B):
            deleted = await repo.delete(fb["feedback_id"])
            assert deleted is False, "User B deleted User A's feedback"

        # A's feedback still retrievable.
        with _as_user(USER_A):
            row = await repo.get(fb["feedback_id"])
            assert row is not None
    finally:
        await cleanup()


# ── Regression: AUTO sentinel without contextvar must raise ───────────────


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_repository_without_context_raises(tmp_path):
    """Defense-in-depth: calling repo methods without a user context errors."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())
        # Contextvar is explicitly unset under @pytest.mark.no_auto_user.
        with pytest.raises(RuntimeError, match="no user context is set"):
            await repo.get("anything")
    finally:
        await cleanup()


# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──


@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path):
    """Migration scripts pass owner_id=None to see all rows regardless of owner."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        # Seed data as two different users.
        with _as_user(USER_A):
            await repo.create("t-alpha")
        with _as_user(USER_B):
            await repo.create("t-beta")

        # Migration-style read: no contextvar, explicit None bypass.
        all_rows = await repo.search(owner_id=None)
        thread_ids = {r["thread_id"] for r in all_rows}
        assert thread_ids == {"t-alpha", "t-beta"}

        # Explicit get with None does not apply the filter either.
        row_a = await repo.get("t-alpha", owner_id=None)
        assert row_a is not None
        row_b = await repo.get("t-beta", owner_id=None)
        assert row_b is not None
    finally:
        await cleanup()

@@ -0,0 +1,232 @@
"""Tests for the persistence layer scaffolding.

Tests:
1. DatabaseConfig property derivation (paths, URLs)
2. MemoryRunStore CRUD + owner_id filtering
3. Base.to_dict() via inspect mixin
4. Engine init/close lifecycle (memory + SQLite)
5. Postgres missing-dep error message
"""
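# Hedged sketch of the URL derivation the TestDatabaseConfig cases below pin
# down -- inferred from the assertions, not copied from DatabaseConfig itself:
# postgres URLs get the asyncpg driver spliced in exactly once, sqlite derives
# app.db under sqlite_dir, and memory has no URL at all.
def _sketch_app_url(backend: str, postgres_url: str = "", sqlite_dir: str = "") -> str:
    if backend == "postgres":
        if "+asyncpg" in postgres_url:
            return postgres_url  # already async, don't double the driver
        return postgres_url.replace("postgresql://", "postgresql+asyncpg://", 1)
    if backend == "sqlite":
        return f"sqlite+aiosqlite:///{sqlite_dir}/app.db"
    raise ValueError(f"No SQLAlchemy URL for backend {backend!r}")
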
from datetime import UTC, datetime

import pytest

from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.runs.store.memory import MemoryRunStore

# -- DatabaseConfig --


class TestDatabaseConfig:
    def test_defaults(self):
        c = DatabaseConfig()
        assert c.backend == "memory"
        assert c.pool_size == 5

    def test_sqlite_paths_are_different(self):
        c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
        assert c.checkpointer_sqlite_path.endswith("checkpoints.db")
        assert c.app_sqlite_path.endswith("app.db")
        assert "mydata" in c.checkpointer_sqlite_path
        assert c.checkpointer_sqlite_path != c.app_sqlite_path

    def test_app_sqlalchemy_url_sqlite(self):
        c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
        url = c.app_sqlalchemy_url
        assert url.startswith("sqlite+aiosqlite:///")
        assert "app.db" in url

    def test_app_sqlalchemy_url_postgres(self):
        c = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql://u:p@h:5432/db",
        )
        url = c.app_sqlalchemy_url
        assert url.startswith("postgresql+asyncpg://")
        assert "u:p@h:5432/db" in url

    def test_app_sqlalchemy_url_postgres_already_asyncpg(self):
        c = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql+asyncpg://u:p@h:5432/db",
        )
        url = c.app_sqlalchemy_url
        assert url.count("asyncpg") == 1

    def test_memory_has_no_url(self):
        c = DatabaseConfig(backend="memory")
        with pytest.raises(ValueError, match="No SQLAlchemy URL"):
            _ = c.app_sqlalchemy_url


# -- MemoryRunStore --


class TestMemoryRunStore:
    @pytest.fixture
    def store(self):
        return MemoryRunStore()

    @pytest.mark.anyio
    async def test_put_and_get(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        row = await store.get("r1")
        assert row is not None
        assert row["run_id"] == "r1"
        assert row["status"] == "pending"

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, store):
        assert await store.get("nope") is None

    @pytest.mark.anyio
    async def test_update_status(self, store):
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "running")
        assert (await store.get("r1"))["status"] == "running"

    @pytest.mark.anyio
    async def test_update_status_with_error(self, store):
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "error", error="boom")
        row = await store.get("r1")
        assert row["status"] == "error"
        assert row["error"] == "boom"

    @pytest.mark.anyio
    async def test_list_by_thread(self, store):
        await store.put("r1", thread_id="t1")
        await store.put("r2", thread_id="t1")
        await store.put("r3", thread_id="t2")
        rows = await store.list_by_thread("t1")
        assert len(rows) == 2
        assert all(r["thread_id"] == "t1" for r in rows)

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, store):
        await store.put("r1", thread_id="t1", owner_id="alice")
        await store.put("r2", thread_id="t1", owner_id="bob")
        rows = await store.list_by_thread("t1", owner_id="alice")
        assert len(rows) == 1
        assert rows[0]["owner_id"] == "alice"

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, store):
        await store.put("r1", thread_id="t1", owner_id="alice")
        await store.put("r2", thread_id="t1", owner_id="bob")
        rows = await store.list_by_thread("t1", owner_id=None)
        assert len(rows) == 2

    @pytest.mark.anyio
    async def test_delete(self, store):
        await store.put("r1", thread_id="t1")
        await store.delete("r1")
        assert await store.get("r1") is None

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, store):
        await store.delete("nope")  # should not raise

    @pytest.mark.anyio
    async def test_list_pending(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        await store.put("r2", thread_id="t1", status="running")
        await store.put("r3", thread_id="t2", status="pending")
        pending = await store.list_pending()
        assert len(pending) == 2
        assert all(r["status"] == "pending" for r in pending)

    @pytest.mark.anyio
    async def test_list_pending_respects_before(self, store):
        past = "2020-01-01T00:00:00+00:00"
        future = "2099-01-01T00:00:00+00:00"
        await store.put("r1", thread_id="t1", status="pending", created_at=past)
        await store.put("r2", thread_id="t1", status="pending", created_at=future)
        pending = await store.list_pending(before=datetime.now(UTC).isoformat())
        assert len(pending) == 1
        assert pending[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_list_pending_fifo_order(self, store):
        await store.put("r2", thread_id="t1", status="pending", created_at="2024-01-02T00:00:00+00:00")
        await store.put("r1", thread_id="t1", status="pending", created_at="2024-01-01T00:00:00+00:00")
        pending = await store.list_pending()
        assert pending[0]["run_id"] == "r1"


# -- Base.to_dict mixin --


class TestBaseToDictMixin:
    @pytest.mark.anyio
    async def test_to_dict_and_exclude(self, tmp_path):
        """Create a temp SQLite DB with a minimal model, verify to_dict."""
        from sqlalchemy import String
        from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
        from sqlalchemy.orm import Mapped, mapped_column

        from deerflow.persistence.base import Base

        class _Tmp(Base):
            __tablename__ = "_tmp_test"
            id: Mapped[str] = mapped_column(String(64), primary_key=True)
            name: Mapped[str] = mapped_column(String(128))

        engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        sf = async_sessionmaker(engine, expire_on_commit=False)
        async with sf() as session:
            session.add(_Tmp(id="1", name="hello"))
            await session.commit()
            obj = await session.get(_Tmp, "1")

        assert obj.to_dict() == {"id": "1", "name": "hello"}
        assert obj.to_dict(exclude={"name"}) == {"id": "1"}
        assert "_Tmp" in repr(obj)

        await engine.dispose()


# -- Engine lifecycle --


class TestEngineLifecycle:
    @pytest.mark.anyio
    async def test_memory_is_noop(self):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        await init_engine("memory")
        assert get_session_factory() is None
        await close_engine()

    @pytest.mark.anyio
    async def test_sqlite_creates_engine(self, tmp_path):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        sf = get_session_factory()
        assert sf is not None
        async with sf() as session:
            assert session is not None
        await close_engine()
        assert get_session_factory() is None

    @pytest.mark.anyio
    async def test_postgres_without_asyncpg_gives_actionable_error(self):
        """If asyncpg is not installed, error message tells user what to do."""
        from deerflow.persistence.engine import init_engine

        try:
            import asyncpg  # noqa: F401

            pytest.skip("asyncpg is installed -- cannot test missing-dep path")
        except ImportError:
            # asyncpg is not installed — this is the expected state for this test.
            # We proceed to verify that init_engine raises an actionable ImportError.
            pass  # noqa: S110 — intentionally ignored
        with pytest.raises(ImportError, match="uv sync --extra postgres"):
            await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")

@@ -0,0 +1,500 @@
"""Tests for RunEventStore contract across all backends.

Uses a helper to create the store for each backend type.
Memory tests run directly; DB and JSONL tests create stores inside each test.
"""
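# Hedged sketch of the seq contract every backend must honor, mirrored from
# the assertions below (the class name _SketchSeqCounter is an assumption,
# not the stores' actual internals): seq starts at 1 and increases strictly
# per thread, independently across threads.
from collections import defaultdict


class _SketchSeqCounter:
    def __init__(self) -> None:
        self._next = defaultdict(lambda: 1)  # thread_id -> next seq

    def take(self, thread_id: str) -> int:
        seq = self._next[thread_id]
        self._next[thread_id] = seq + 1
        return seq
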
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
# -- Basic write and query --
|
||||
|
||||
|
||||
class TestPutAndSeq:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_returns_dict_with_seq(self, store):
|
||||
record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello")
|
||||
assert "seq" in record
|
||||
assert record["seq"] == 1
|
||||
assert record["thread_id"] == "t1"
|
||||
assert record["run_id"] == "r1"
|
||||
assert record["event_type"] == "human_message"
|
||||
assert record["category"] == "message"
|
||||
assert record["content"] == "hello"
|
||||
assert "created_at" in record
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_seq_strictly_increasing_same_thread(self, store):
|
||||
r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
r2 = await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
r3 = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
assert r1["seq"] == 1
|
||||
assert r2["seq"] == 2
|
||||
assert r3["seq"] == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_seq_independent_across_threads(self, store):
|
||||
r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
r2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message")
|
||||
assert r1["seq"] == 1
|
||||
assert r2["seq"] == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_respects_provided_created_at(self, store):
|
||||
ts = "2024-06-01T12:00:00+00:00"
|
||||
record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", created_at=ts)
|
||||
assert record["created_at"] == ts
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_metadata_preserved(self, store):
|
||||
meta = {"model": "gpt-4", "tokens": 100}
|
||||
record = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", metadata=meta)
|
||||
assert record["metadata"] == meta
|
||||
|
||||
|
||||
# -- list_messages --
|
||||
|
||||
|
||||
class TestListMessages:
|
||||
@pytest.mark.anyio
|
||||
async def test_only_returns_message_category(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle")
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["category"] == "message"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ascending_seq_order(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="first")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="second")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="third")
|
||||
messages = await store.list_messages("t1")
|
||||
seqs = [m["seq"] for m in messages]
|
||||
assert seqs == sorted(seqs)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_before_seq_pagination(self, store):
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", before_seq=6, limit=3)
|
||||
assert len(messages) == 3
|
||||
assert [m["seq"] for m in messages] == [3, 4, 5]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_after_seq_pagination(self, store):
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", after_seq=7, limit=3)
|
||||
assert len(messages) == 3
|
||||
assert [m["seq"] for m in messages] == [8, 9, 10]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_limit_restricts_count(self, store):
|
||||
for _ in range(20):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
messages = await store.list_messages("t1", limit=5)
|
||||
assert len(messages) == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cross_run_unified_ordering(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
|
||||
messages = await store.list_messages("t1")
|
||||
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
assert messages[2]["run_id"] == "r2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_returns_latest(self, store):
|
||||
for _ in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
messages = await store.list_messages("t1", limit=3)
|
||||
assert [m["seq"] for m in messages] == [8, 9, 10]
|
||||
|
||||
|
||||
# -- list_events --
|
||||
|
||||
|
||||
class TestListEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_all_categories_for_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle")
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_event_types_filter(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_start", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="tool_start", category="trace")
|
||||
events = await store.list_events("t1", "r1", event_types=["llm_end"])
|
||||
assert len(events) == 1
|
||||
assert events[0]["event_type"] == "llm_end"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_only_returns_specified_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) == 1
|
||||
assert events[0]["run_id"] == "r1"
|
||||
|
||||
|
||||
# -- list_messages_by_run --
|
||||
|
||||
|
||||
class TestListMessagesByRun:
|
||||
@pytest.mark.anyio
|
||||
async def test_only_messages_for_specified_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
messages = await store.list_messages_by_run("t1", "r1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
assert messages[0]["category"] == "message"
|
||||
|
||||
|
||||
# -- count_messages --
|
||||
|
||||
|
||||
class TestCountMessages:
|
||||
@pytest.mark.anyio
|
||||
async def test_counts_only_message_category(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
assert await store.count_messages("t1") == 2
|
||||
|
||||
|
||||
# -- put_batch --
|
||||
|
||||
|
||||
class TestPutBatch:
|
||||
@pytest.mark.anyio
|
||||
async def test_batch_assigns_seq(self, store):
|
||||
events = [
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"},
|
||||
]
|
||||
results = await store.put_batch(events)
|
||||
assert len(results) == 3
|
||||
assert all("seq" in r for r in results)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_batch_seq_strictly_increasing(self, store):
|
||||
events = [
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message"},
|
||||
]
|
||||
results = await store.put_batch(events)
|
||||
assert results[0]["seq"] == 1
|
||||
assert results[1]["seq"] == 2
|
||||
|
||||
|
||||
# -- delete --
|
||||
|
||||
|
||||
class TestDelete:
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_thread(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_thread("t1")
|
||||
assert count == 3
|
||||
assert await store.list_messages("t1") == []
|
||||
assert await store.count_messages("t1") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_run("t1", "r2")
|
||||
assert count == 2
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_thread_returns_zero(self, store):
|
||||
assert await store.delete_by_thread("nope") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_run_returns_zero(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
assert await store.delete_by_run("t1", "nope") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_thread_for_run_returns_zero(self, store):
|
||||
assert await store.delete_by_run("nope", "r1") == 0
|
||||
|
||||
|
||||
# -- Edge cases --
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_list_messages(self, store):
|
||||
assert await store.list_messages("empty") == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_run_list_events(self, store):
|
||||
assert await store.list_events("empty", "r1") == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_count_messages(self, store):
|
||||
assert await store.count_messages("empty") == 0
|
||||
|
||||
|
||||
# -- DB-specific tests --
|
||||
|
||||
|
||||
class TestDbRunEventStore:
|
||||
"""Tests for DbRunEventStore with temp SQLite."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
|
||||
assert r["seq"] == 1
|
||||
r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
|
||||
assert r2["seq"] == 2
|
||||
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
|
||||
count = await s.count_messages("t1")
|
||||
assert count == 2
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_trace_content_truncation(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory(), max_trace_content=100)
|
||||
|
||||
long = "x" * 200
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
|
||||
assert len(r["content"]) == 100
|
||||
assert r["metadata"].get("content_truncated") is True
|
||||
|
||||
# message content NOT truncated
|
||||
m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
|
||||
assert len(m["content"]) == 200
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
for i in range(10):
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
|
||||
# before_seq
|
||||
msgs = await s.list_messages("t1", before_seq=6, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [3, 4, 5]
|
||||
|
||||
# after_seq
|
||||
msgs = await s.list_messages("t1", after_seq=7, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
# default (latest)
|
||||
msgs = await s.list_messages("t1", limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
|
||||
c = await s.delete_by_run("t1", "r2")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 1
|
||||
|
||||
c = await s.delete_by_thread("t1")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 0
|
||||
|
||||
await close_engine()
|
||||
|
||||
    @pytest.mark.anyio
    async def test_put_batch_seq_continuity(self, tmp_path):
        """Batch write produces continuous seq values with no gaps."""
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
        from deerflow.runtime.events.store.db import DbRunEventStore

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        s = DbRunEventStore(get_session_factory())

        events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)]
        results = await s.put_batch(events)
        seqs = [r["seq"] for r in results]
        assert seqs == list(range(1, 51))
        await close_engine()
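
# The gap-free guarantee above suggests a batch's seq values are allocated as
# one contiguous block, e.g. inside a single transaction. A hypothetical sketch
# of that allocation (illustrative only, not DeerFlow's actual code):
def _sketch_allocate_batch(next_seq: int, batch_size: int) -> tuple[list[int], int]:
    """Return the seqs for this batch plus the advanced counter value."""
    return list(range(next_seq, next_seq + batch_size)), next_seq + batch_size
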
# -- Factory tests --


class TestMakeRunEventStore:
    """Tests for the make_run_event_store factory function."""

    @pytest.mark.anyio
    async def test_memory_backend_default(self):
        from deerflow.runtime.events.store import make_run_event_store

        store = make_run_event_store(None)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_memory_backend_explicit(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "memory"
        store = make_run_event_store(config)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_db_backend_with_engine(self, tmp_path):
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))

        config = MagicMock()
        config.backend = "db"
        config.max_trace_content = 10240
        store = make_run_event_store(config)
        assert type(store).__name__ == "DbRunEventStore"
        await close_engine()

    @pytest.mark.anyio
    async def test_db_backend_no_engine_falls_back(self):
        """db backend without engine falls back to memory."""
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        await init_engine("memory")  # no engine created

        config = MagicMock()
        config.backend = "db"
        store = make_run_event_store(config)
        assert type(store).__name__ == "MemoryRunEventStore"
        await close_engine()

    @pytest.mark.anyio
    async def test_jsonl_backend(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "jsonl"
        store = make_run_event_store(config)
        assert type(store).__name__ == "JsonlRunEventStore"

    @pytest.mark.anyio
    async def test_unknown_backend_raises(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "redis"
        with pytest.raises(ValueError, match="Unknown"):
            make_run_event_store(config)
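
# Taken together, the factory tests above pin down a four-way dispatch. A
# hypothetical restatement of that contract (the real make_run_event_store in
# deerflow.runtime.events.store returns store instances, not class names):
def _sketch_dispatch(backend: str | None, engine_ready: bool) -> str:
    if backend is None or backend == "memory":
        return "MemoryRunEventStore"
    if backend == "db":
        # "db" without an initialized engine degrades gracefully to memory
        return "DbRunEventStore" if engine_ready else "MemoryRunEventStore"
    if backend == "jsonl":
        return "JsonlRunEventStore"
    raise ValueError(f"Unknown run event store backend: {backend}")
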
# -- JSONL-specific tests --


class TestJsonlRunEventStore:
    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
        assert r["seq"] == 1
        messages = await s.list_messages("t1")
        assert len(messages) == 1

    @pytest.mark.anyio
    async def test_file_at_correct_path(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists()

    @pytest.mark.anyio
    async def test_cross_run_messages(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        messages = await s.list_messages("t1")
        assert len(messages) == 2
        assert [m["seq"] for m in messages] == [1, 2]

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        c = await s.delete_by_run("t1", "r2")
        assert c == 1
        assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists()
        assert await s.count_messages("t1") == 1
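
# Hypothetical sketch of cross-run listing for the JSONL layout the tests above
# assert (one append-only file per run under threads/<thread_id>/runs/): read
# every run file and merge message records by their global seq. Helper names
# and the category filter are illustrative, not DeerFlow's actual code.
import json
from pathlib import Path


def _sketch_jsonl_list_messages(base_dir: Path, thread_id: str) -> list[dict]:
    runs_dir = base_dir / "threads" / thread_id / "runs"
    events: list[dict] = []
    for path in sorted(runs_dir.glob("*.jsonl")):
        with path.open() as f:
            events.extend(json.loads(line) for line in f if line.strip())
    return sorted((e for e in events if e.get("category") == "message"), key=lambda e: e["seq"])
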
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff