security(auth): wire @require_permission(owner_check=True) on isolation routes

Apply the require_permission decorator to all 28 routes that take a
{thread_id} path parameter. Combined with the strict middleware
(previous commit), this gives the double-layer protection that
AUTH_TEST_PLAN test 7.5.9 documents:

  Layer 1 (AuthMiddleware): cookie + JWT validation, rejects junk
                            cookies and stamps request.state.user
  Layer 2 (@require_permission with owner_check=True): per-resource
                            ownership verification via
                            ThreadMetaStore.check_access — returns
                            404 if a different user owns the thread

The decorator's owner_check branch is rewritten to use the SQL
thread_meta_repo (the 2.0-rc persistence layer) instead of the
LangGraph store path that PR #1728 used (_store_get / get_store
in routers/threads.py). The inject_record convenience is dropped
— no caller in 2.0 needs the LangGraph blob, and the SQL repo has
a different shape.

Routes decorated (28 total):
- threads.py: delete, patch, get, get-state, post-state, post-history
- thread_runs.py: post-runs, post-runs-stream, post-runs-wait,
  list_runs, get_run, cancel_run, join_run, stream_existing_run,
  list_thread_messages, list_run_messages, list_run_events,
  thread_token_usage
- feedback.py: create, list, stats, delete
- uploads.py: upload (added Request param), list, delete
- artifacts.py: get_artifact
- suggestions.py: generate (renamed body parameter to avoid
  conflict with FastAPI Request)

Test fixes:
- test_suggestions_router.py: bypass the decorator via __wrapped__
  (the unit tests cover parsing logic, not auth — no point spinning
  up a thread_meta_repo just to test JSON unwrapping)
- test_auth_middleware.py 4 fake-cookie tests: already updated in
  the previous commit (745bf432)

Tests: 293 passed (auth + persistence + isolation + suggestions)
Lint: clean
This commit is contained in:
greatmengqi
2026-04-08 13:32:39 +08:00
parent 745bf4324e
commit 2b33bfd78f
8 changed files with 79 additions and 29 deletions
+24 -16
View File
@@ -231,28 +231,36 @@ def require_permission(
detail=f"Permission denied: {resource}:{action}", detail=f"Permission denied: {resource}:{action}",
) )
# Owner check for thread-specific resources # Owner check for thread-specific resources.
#
# 2.0-rc moved thread metadata into the SQL persistence layer
# (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access`` instead of the LangGraph
# store path that the original PR #1728 used. ``check_access``
# returns True for missing rows (untracked legacy thread) and
# for rows whose ``owner_id`` is NULL (shared / pre-auth data),
# so this is a strict-deny check rather than strict-allow:
# only an *existing* row with a *different* owner_id triggers
# 404.
#
# ``inject_record`` is no longer supported — it was a
# convenience for handlers that wanted the LangGraph store
# blob; the SQL repo would need a different shape and no
# caller in 2.0 needs it.
if owner_check: if owner_check:
thread_id = kwargs.get("thread_id") thread_id = kwargs.get("thread_id")
if thread_id is None: if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
# Get thread and verify ownership from app.gateway.deps import get_thread_meta_repo
from app.gateway.routers.threads import _store_get, get_store
store = get_store(request) thread_meta_repo = get_thread_meta_repo(request)
if store is not None: allowed = await thread_meta_repo.check_access(thread_id, str(auth.user.id))
record = await _store_get(store, thread_id) if not allowed:
if record: raise HTTPException(
owner_id = record.get("metadata", {}).get(owner_filter_key) status_code=404,
if owner_id and owner_id != str(auth.user.id): detail=f"Thread {thread_id} not found",
raise HTTPException( )
status_code=404,
detail=f"Thread {thread_id} not found",
)
# Inject record if requested
if inject_record:
kwargs["thread_record"] = record
return await func(*args, **kwargs) return await func(*args, **kwargs)
+2
View File
@@ -7,6 +7,7 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File", summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.", description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
) )
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response: async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path. """Get an artifact file by its path.
+5
View File
@@ -12,6 +12,7 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ class FeedbackStatsResponse(BaseModel):
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) @router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True)
async def create_feedback( async def create_feedback(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -85,6 +87,7 @@ async def create_feedback(
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse]) @router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback( async def list_feedback(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -96,6 +99,7 @@ async def list_feedback(
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse) @router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats( async def feedback_stats(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -107,6 +111,7 @@ async def feedback_stats(
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}") @router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True)
async def delete_feedback( async def delete_feedback(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
+8 -6
View File
@@ -1,10 +1,11 @@
import json import json
import logging import logging
from fastapi import APIRouter from fastapi import APIRouter, Request
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -98,12 +99,13 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions", summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.", description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
) )
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse: @require_permission("threads", "read", owner_check=True)
if not request.messages: async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
n = request.n n = body.n
conversation = _format_conversation(request.messages) conversation = _format_conversation(body.messages)
if not conversation: if not conversation:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
@@ -120,7 +122,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try: try:
model = create_chat_model(name=request.model_name, thinking_enabled=False) model = create_chat_model(name=body.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content) raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or [] suggestions = _parse_json_string_list(raw) or []
@@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values from deerflow.runtime import RunRecord, serialize_channel_values
@@ -93,6 +94,7 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse) @router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse: async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately).""" """Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request) record = await start_run(body, thread_id, request)
@@ -100,6 +102,7 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/stream") @router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse: async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE. """Create a run and stream events via SSE.
@@ -127,6 +130,7 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict) @router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict: async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state.""" """Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request) record = await start_run(body, thread_id, request)
@@ -152,6 +156,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse]) @router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread.""" """List all runs for a thread."""
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
@@ -160,6 +165,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse) @router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run.""" """Get details of a specific run."""
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
@@ -170,6 +176,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel") @router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True)
async def cancel_run( async def cancel_run(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -207,6 +214,7 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join") @router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream.""" """Join an existing run's SSE stream."""
bridge = get_stream_bridge(request) bridge = get_stream_bridge(request)
@@ -227,6 +235,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None) @router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run( async def stream_existing_run(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -274,6 +283,7 @@ async def stream_existing_run(
@router.get("/{thread_id}/messages") @router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages( async def list_thread_messages(
thread_id: str, thread_id: str,
request: Request, request: Request,
@@ -287,6 +297,7 @@ async def list_thread_messages(
@router.get("/{thread_id}/runs/{run_id}/messages") @router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]: async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
"""Return displayable messages for a specific run.""" """Return displayable messages for a specific run."""
event_store = get_run_event_store(request) event_store = get_run_event_store(request)
@@ -294,6 +305,7 @@ async def list_run_messages(thread_id: str, run_id: str, request: Request) -> li
@router.get("/{thread_id}/runs/{run_id}/events") @router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events( async def list_run_events(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -308,6 +320,7 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage") @router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict: async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation.""" """Thread-level token usage aggregation."""
run_store = get_run_store(request) run_store = get_run_store(request)
+7
View File
@@ -20,6 +20,7 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
@@ -165,6 +166,7 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse) @router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse: async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread. """Delete local persisted filesystem data for a thread.
@@ -293,6 +295,7 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
@router.patch("/{thread_id}", response_model=ThreadResponse) @router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record.""" """Merge metadata into a thread record."""
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_meta_repo
@@ -320,6 +323,7 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
@router.get("/{thread_id}", response_model=ThreadResponse) @router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse: async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info. """Get thread info.
@@ -376,6 +380,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
@router.get("/{thread_id}/state", response_model=ThreadStateResponse) @router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread. """Get the latest state snapshot for a thread.
@@ -425,6 +430,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse) @router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse: async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename). """Update thread state (e.g. for human-in-the-loop resume or title rename).
@@ -514,6 +520,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry]) @router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]: async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread. """Get checkpoint history for a thread.
+8 -3
View File
@@ -4,9 +4,10 @@ import logging
import os import os
import stat import stat
from fastapi import APIRouter, File, HTTPException, UploadFile from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import ( from deerflow.uploads.manager import (
@@ -54,8 +55,10 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
request: Request,
files: list[UploadFile] = File(...), files: list[UploadFile] = File(...),
) -> UploadResponse: ) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory.""" """Upload multiple files to a thread's uploads directory."""
@@ -133,7 +136,8 @@ async def upload_files(
@router.get("/list", response_model=dict) @router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict: @require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory.""" """List all files in a thread's uploads directory."""
try: try:
uploads_dir = get_uploads_dir(thread_id) uploads_dir = get_uploads_dir(thread_id)
@@ -151,7 +155,8 @@ async def list_uploaded_files(thread_id: str) -> dict:
@router.delete("/{filename}") @router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict: @require_permission("threads", "delete", owner_check=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory.""" """Delete a file from a thread's uploads directory."""
try: try:
uploads_dir = get_uploads_dir(thread_id) uploads_dir = get_uploads_dir(thread_id)
+12 -4
View File
@@ -46,7 +46,9 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')) fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req)) # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2", "Q3"] assert result.suggestions == ["Q1", "Q2", "Q3"]
@@ -64,7 +66,9 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])) fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req)) # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
@@ -82,7 +86,9 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])) fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req)) # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
@@ -97,6 +103,8 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom")) fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req)) # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == [] assert result.suggestions == []