Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 14005b4255 |
@@ -31,13 +31,6 @@ https://github.com/user-attachments/assets/f3786598-1f2a-4d07-919e-8b99dfa1de3e
|
||||
- [如何装饰租赁公寓?](https://deerflow.tech/chat?replay=rental-apartment-decoration)
|
||||
- [访问我们的官方网站探索更多回放示例。](https://deerflow.tech/#case-studies)
|
||||
|
||||
### 火山引擎
|
||||
|
||||
目前,DeerFlow 已正式入驻[火山引擎的 FaaS 应用中心](https://console.volcengine.com/vefaas/region:vefaas+cn-beijing/market),用户可通过体验链接进行在线体验,直观感受其强大功能与便捷操作;同时,为满足不同用户的部署需求,DeerFlow 支持基于火山引擎一键部署,点击部署链接即可快速完成部署流程,开启高效研究之旅。[快来看看吧](https://console.volcengine.com/vefaas/region:vefaas+cn-beijing/market)~
|
||||
|
||||
<img width="1800" alt="截屏2025-06-12 13 25 12" src="https://github.com/user-attachments/assets/73c15966-6b79-4dc0-8803-efdaf7c4015e" />
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 📑 目录
|
||||
|
||||
@@ -16,5 +16,4 @@ AGENT_LLM_MAP: dict[str, LLMType] = {
|
||||
"podcast_script_writer": "basic",
|
||||
"ppt_composer": "basic",
|
||||
"prose_writer": "basic",
|
||||
"prompt_enhancer": "basic",
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Any, Optional
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.rag.retriever import Resource
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@@ -22,7 +21,6 @@ class Configuration:
|
||||
max_step_num: int = 3 # Maximum number of steps in a plan
|
||||
max_search_results: int = 3 # Maximum number of search results
|
||||
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
||||
report_style: str = ReportStyle.ACADEMIC.value # Report style
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
||||
@@ -12,7 +12,7 @@ def replace_env_vars(value: str) -> str:
|
||||
return value
|
||||
if value.startswith("$"):
|
||||
env_var = value[1:]
|
||||
return os.getenv(env_var, env_var)
|
||||
return os.getenv(env_var, value)
|
||||
return value
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
import enum
|
||||
|
||||
|
||||
class ReportStyle(enum.Enum):
|
||||
ACADEMIC = "academic"
|
||||
POPULAR_SCIENCE = "popular_science"
|
||||
NEWS = "news"
|
||||
SOCIAL_MEDIA = "social_media"
|
||||
@@ -3,7 +3,8 @@
|
||||
|
||||
from .article import Article
|
||||
from .crawler import Crawler
|
||||
from .jina_client import JinaClient
|
||||
from .readability_extractor import ReadabilityExtractor
|
||||
|
||||
__all__ = ["Article", "Crawler", "JinaClient", "ReadabilityExtractor"]
|
||||
__all__ = [
|
||||
"Article",
|
||||
"Crawler",
|
||||
]
|
||||
|
||||
@@ -26,3 +26,13 @@ class Crawler:
|
||||
article = extractor.extract_article(html)
|
||||
article.url = url
|
||||
return article
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) == 2:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://fintel.io/zh-hant/s/br/nvdc34"
|
||||
crawler = Crawler()
|
||||
article = crawler.crawl(url)
|
||||
print(article.to_markdown())
|
||||
|
||||
@@ -18,7 +18,7 @@ from .nodes import (
|
||||
)
|
||||
|
||||
|
||||
def continue_to_running_research_team(state: State):
|
||||
def continue_to_running_research_step(state: State):
|
||||
current_plan = state.get("current_plan")
|
||||
if not current_plan or not current_plan.steps:
|
||||
return "planner"
|
||||
@@ -49,7 +49,7 @@ def _build_base_graph():
|
||||
builder.add_edge("background_investigator", "planner")
|
||||
builder.add_conditional_edges(
|
||||
"research_team",
|
||||
continue_to_running_research_team,
|
||||
continue_to_running_research_step,
|
||||
["planner", "researcher", "coder"],
|
||||
)
|
||||
builder.add_edge("reporter", END)
|
||||
|
||||
+11
-18
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@tool
|
||||
def handoff_to_planner(
|
||||
research_topic: Annotated[str, "The topic of the research task to be handed off."],
|
||||
task_title: Annotated[str, "The title of the task to be handed off."],
|
||||
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
|
||||
):
|
||||
"""Handoff to planner agent to do plan."""
|
||||
@@ -48,7 +48,7 @@ def handoff_to_planner(
|
||||
def background_investigation_node(state: State, config: RunnableConfig):
|
||||
logger.info("background investigation node is running.")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
query = state.get("research_topic")
|
||||
query = state["messages"][-1].content
|
||||
background_investigation_results = None
|
||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||
searched_content = LoggedTavilySearch(
|
||||
@@ -87,8 +87,10 @@ def planner_node(
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
messages = apply_prompt_template("planner", state, configurable)
|
||||
|
||||
if state.get("enable_background_investigation") and state.get(
|
||||
"background_investigation_results"
|
||||
if (
|
||||
plan_iterations == 0
|
||||
and state.get("enable_background_investigation")
|
||||
and state.get("background_investigation_results")
|
||||
):
|
||||
messages += [
|
||||
{
|
||||
@@ -219,7 +221,6 @@ def coordinator_node(
|
||||
|
||||
goto = "__end__"
|
||||
locale = state.get("locale", "en-US") # Default locale if not specified
|
||||
research_topic = state.get("research_topic", "")
|
||||
|
||||
if len(response.tool_calls) > 0:
|
||||
goto = "planner"
|
||||
@@ -230,11 +231,8 @@ def coordinator_node(
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.get("name", "") != "handoff_to_planner":
|
||||
continue
|
||||
if tool_call.get("args", {}).get("locale") and tool_call.get(
|
||||
"args", {}
|
||||
).get("research_topic"):
|
||||
locale = tool_call.get("args", {}).get("locale")
|
||||
research_topic = tool_call.get("args", {}).get("research_topic")
|
||||
if tool_locale := tool_call.get("args", {}).get("locale"):
|
||||
locale = tool_locale
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tool calls: {e}")
|
||||
@@ -245,19 +243,14 @@ def coordinator_node(
|
||||
logger.debug(f"Coordinator response: {response}")
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"locale": locale,
|
||||
"research_topic": research_topic,
|
||||
"resources": configurable.resources,
|
||||
},
|
||||
update={"locale": locale, "resources": configurable.resources},
|
||||
goto=goto,
|
||||
)
|
||||
|
||||
|
||||
def reporter_node(state: State, config: RunnableConfig):
|
||||
def reporter_node(state: State):
|
||||
"""Reporter node that write a final report."""
|
||||
logger.info("Reporter write final report")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
current_plan = state.get("current_plan")
|
||||
input_ = {
|
||||
"messages": [
|
||||
@@ -267,7 +260,7 @@ def reporter_node(state: State, config: RunnableConfig):
|
||||
],
|
||||
"locale": state.get("locale", "en-US"),
|
||||
}
|
||||
invoke_messages = apply_prompt_template("reporter", input_, configurable)
|
||||
invoke_messages = apply_prompt_template("reporter", input_)
|
||||
observations = state.get("observations", [])
|
||||
|
||||
# Add a reminder about the new report format, citation style, and table usage
|
||||
|
||||
@@ -12,7 +12,6 @@ class State(MessagesState):
|
||||
|
||||
# Runtime Variables
|
||||
locale: str = "en-US"
|
||||
research_topic: str = ""
|
||||
observations: list[str] = []
|
||||
resources: list[Resource] = []
|
||||
plan_iterations: int = 0
|
||||
|
||||
@@ -70,3 +70,9 @@ def get_llm_by_type(
|
||||
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
||||
# reasoning_llm = get_llm_by_type("reasoning")
|
||||
# vl_llm = get_llm_by_type("vision")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize LLMs for different purposes - now these will be cached
|
||||
basic_llm = get_llm_by_type("basic")
|
||||
print(basic_llm.invoke("Hello"))
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Prompt enhancer module for improving user prompts."""
|
||||
@@ -1,25 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""Build and return the prompt enhancer workflow graph."""
|
||||
# Build state graph
|
||||
builder = StateGraph(PromptEnhancerState)
|
||||
|
||||
# Add the enhancer node
|
||||
builder.add_node("enhancer", prompt_enhancer_node)
|
||||
|
||||
# Set entry point
|
||||
builder.set_entry_point("enhancer")
|
||||
|
||||
# Set finish point
|
||||
builder.set_finish_point("enhancer")
|
||||
|
||||
# Compile and return the graph
|
||||
return builder.compile()
|
||||
@@ -1,67 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.template import env, apply_prompt_template
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prompt_enhancer_node(state: PromptEnhancerState):
|
||||
"""Node that enhances user prompts using AI analysis."""
|
||||
logger.info("Enhancing user prompt...")
|
||||
|
||||
model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])
|
||||
|
||||
try:
|
||||
|
||||
# Create messages with context if provided
|
||||
context_info = ""
|
||||
if state.get("context"):
|
||||
context_info = f"\n\nAdditional context: {state['context']}"
|
||||
|
||||
original_prompt_message = HumanMessage(
|
||||
content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
|
||||
)
|
||||
|
||||
messages = apply_prompt_template(
|
||||
"prompt_enhancer/prompt_enhancer",
|
||||
{
|
||||
"messages": [original_prompt_message],
|
||||
"report_style": state.get("report_style"),
|
||||
},
|
||||
)
|
||||
|
||||
# Get the response from the model
|
||||
response = model.invoke(messages)
|
||||
|
||||
# Clean up the response - remove any extra formatting or comments
|
||||
enhanced_prompt = response.content.strip()
|
||||
|
||||
# Remove common prefixes that might be added by the model
|
||||
prefixes_to_remove = [
|
||||
"Enhanced Prompt:",
|
||||
"Enhanced prompt:",
|
||||
"Here's the enhanced prompt:",
|
||||
"Here is the enhanced prompt:",
|
||||
"**Enhanced Prompt**:",
|
||||
"**Enhanced prompt**:",
|
||||
]
|
||||
|
||||
for prefix in prefixes_to_remove:
|
||||
if enhanced_prompt.startswith(prefix):
|
||||
enhanced_prompt = enhanced_prompt[len(prefix) :].strip()
|
||||
break
|
||||
|
||||
logger.info("Prompt enhancement completed successfully")
|
||||
logger.debug(f"Enhanced prompt: {enhanced_prompt}")
|
||||
return {"output": enhanced_prompt}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prompt enhancement: {str(e)}")
|
||||
return {"output": state["prompt"]}
|
||||
@@ -1,14 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import TypedDict, Optional
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class PromptEnhancerState(TypedDict):
|
||||
"""State for the prompt enhancer workflow."""
|
||||
|
||||
prompt: str # Original prompt to enhance
|
||||
context: Optional[str] # Additional context
|
||||
report_style: Optional[ReportStyle] # Report style preference
|
||||
output: Optional[str] # Enhanced prompt result
|
||||
@@ -1,104 +0,0 @@
|
||||
---
|
||||
CURRENT_TIME: {{ CURRENT_TIME }}
|
||||
---
|
||||
|
||||
You are an expert prompt engineer. Your task is to enhance user prompts to make them more effective, specific, and likely to produce high-quality results from AI systems.
|
||||
|
||||
# Your Role
|
||||
- Analyze the original prompt for clarity, specificity, and completeness
|
||||
- Enhance the prompt by adding relevant details, context, and structure
|
||||
- Make the prompt more actionable and results-oriented
|
||||
- Preserve the user's original intent while improving effectiveness
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
# Enhancement Guidelines for Academic Style
|
||||
1. **Add methodological rigor**: Include research methodology, scope, and analytical framework
|
||||
2. **Specify academic structure**: Organize with clear thesis, literature review, analysis, and conclusions
|
||||
3. **Clarify scholarly expectations**: Specify citation requirements, evidence standards, and academic tone
|
||||
4. **Add theoretical context**: Include relevant theoretical frameworks and disciplinary perspectives
|
||||
5. **Ensure precision**: Use precise terminology and avoid ambiguous language
|
||||
6. **Include limitations**: Acknowledge scope limitations and potential biases
|
||||
{% elif report_style == "popular_science" %}
|
||||
# Enhancement Guidelines for Popular Science Style
|
||||
1. **Add accessibility**: Transform technical concepts into relatable analogies and examples
|
||||
2. **Improve narrative structure**: Organize as an engaging story with clear beginning, middle, and end
|
||||
3. **Clarify audience expectations**: Specify general audience level and engagement goals
|
||||
4. **Add human context**: Include real-world applications and human interest elements
|
||||
5. **Make it compelling**: Ensure the prompt guides toward fascinating and wonder-inspiring content
|
||||
6. **Include visual elements**: Suggest use of metaphors and descriptive language for complex concepts
|
||||
{% elif report_style == "news" %}
|
||||
# Enhancement Guidelines for News Style
|
||||
1. **Add journalistic rigor**: Include fact-checking requirements, source verification, and objectivity standards
|
||||
2. **Improve news structure**: Organize with inverted pyramid structure (most important information first)
|
||||
3. **Clarify reporting expectations**: Specify timeliness, accuracy, and balanced perspective requirements
|
||||
4. **Add contextual background**: Include relevant background information and broader implications
|
||||
5. **Make it newsworthy**: Ensure the prompt focuses on current relevance and public interest
|
||||
6. **Include attribution**: Specify source requirements and quote standards
|
||||
{% elif report_style == "social_media" %}
|
||||
# Enhancement Guidelines for Social Media Style
|
||||
1. **Add engagement focus**: Include attention-grabbing elements, hooks, and shareability factors
|
||||
2. **Improve platform structure**: Organize for specific platform requirements (character limits, hashtags, etc.)
|
||||
3. **Clarify audience expectations**: Specify target demographic and engagement goals
|
||||
4. **Add viral elements**: Include trending topics, relatable content, and interactive elements
|
||||
5. **Make it shareable**: Ensure the prompt guides toward content that encourages sharing and discussion
|
||||
6. **Include visual considerations**: Suggest emoji usage, formatting, and visual appeal elements
|
||||
{% else %}
|
||||
# General Enhancement Guidelines
|
||||
1. **Add specificity**: Include relevant details, scope, and constraints
|
||||
2. **Improve structure**: Organize the request logically with clear sections if needed
|
||||
3. **Clarify expectations**: Specify desired output format, length, or style
|
||||
4. **Add context**: Include background information that would help generate better results
|
||||
5. **Make it actionable**: Ensure the prompt guides toward concrete, useful outputs
|
||||
{% endif %}
|
||||
|
||||
# Output Requirements
|
||||
- Output ONLY the enhanced prompt
|
||||
- Do NOT include any explanations, comments, or meta-text
|
||||
- Do NOT use phrases like "Enhanced Prompt:" or "Here's the enhanced version:"
|
||||
- The output should be ready to use directly as a prompt
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
# Academic Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Conduct a comprehensive academic analysis of artificial intelligence applications across three key sectors: healthcare, education, and business. Employ a systematic literature review methodology to examine peer-reviewed sources from the past five years. Structure your analysis with: (1) theoretical framework defining AI and its taxonomies, (2) sector-specific case studies with quantitative performance metrics, (3) critical evaluation of implementation challenges and ethical considerations, (4) comparative analysis across sectors, and (5) evidence-based recommendations for future research directions. Maintain academic rigor with proper citations, acknowledge methodological limitations, and present findings with appropriate hedging language. Target length: 3000-4000 words with APA formatting."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide a rigorous academic examination of anthropogenic climate change, synthesizing current scientific consensus and recent research developments. Structure your analysis as follows: (1) theoretical foundations of greenhouse effect and radiative forcing mechanisms, (2) systematic review of empirical evidence from paleoclimatic, observational, and modeling studies, (3) critical analysis of attribution studies linking human activities to observed warming, (4) evaluation of climate sensitivity estimates and uncertainty ranges, (5) assessment of projected impacts under different emission scenarios, and (6) discussion of research gaps and methodological limitations. Include quantitative data, statistical significance levels, and confidence intervals where appropriate. Cite peer-reviewed sources extensively and maintain objective, third-person academic voice throughout."
|
||||
|
||||
{% elif report_style == "popular_science" %}
|
||||
# Popular Science Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Tell the fascinating story of how artificial intelligence is quietly revolutionizing our daily lives in ways most people never realize. Take readers on an engaging journey through three surprising realms: the hospital where AI helps doctors spot diseases faster than ever before, the classroom where intelligent tutors adapt to each student's learning style, and the boardroom where algorithms are making million-dollar decisions. Use vivid analogies (like comparing neural networks to how our brains work) and real-world examples that readers can relate to. Include 'wow factor' moments that showcase AI's incredible capabilities, but also honest discussions about current limitations. Write with infectious enthusiasm while maintaining scientific accuracy, and conclude with exciting possibilities that await us in the near future. Aim for 1500-2000 words that feel like a captivating conversation with a brilliant friend."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Craft a compelling narrative that transforms the complex science of climate change into an accessible and engaging story for curious readers. Begin with a relatable scenario (like why your hometown weather feels different than when you were a kid) and use this as a gateway to explore the fascinating science behind our changing planet. Employ vivid analogies - compare Earth's atmosphere to a blanket, greenhouse gases to invisible heat-trapping molecules, and climate feedback loops to a snowball rolling downhill. Include surprising facts and 'aha moments' that will make readers think differently about the world around them. Weave in human stories of scientists making discoveries, communities adapting to change, and innovative solutions being developed. Balance the serious implications with hope and actionable insights, concluding with empowering steps readers can take. Write with wonder and curiosity, making complex concepts feel approachable and personally relevant."
|
||||
|
||||
{% elif report_style == "news" %}
|
||||
# News Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Report on the current state and immediate impact of artificial intelligence across three critical sectors: healthcare, education, and business. Lead with the most newsworthy developments and recent breakthroughs that are affecting people today. Structure using inverted pyramid format: start with key findings and immediate implications, then provide essential background context, followed by detailed analysis and expert perspectives. Include specific, verifiable data points, recent statistics, and quotes from credible sources including industry leaders, researchers, and affected stakeholders. Address both benefits and concerns with balanced reporting, fact-check all claims, and provide proper attribution for all information. Focus on timeliness and relevance to current events, highlighting what's happening now and what readers need to know. Maintain journalistic objectivity while making the significance clear to a general news audience. Target 800-1200 words following AP style guidelines."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide comprehensive news coverage of climate change that explains the current scientific understanding and immediate implications for readers. Lead with the most recent and significant developments in climate science, policy, or impacts that are making headlines today. Structure the report with: breaking developments first, essential background for understanding the issue, current scientific consensus with specific data and timeframes, real-world impacts already being observed, policy responses and debates, and what experts say comes next. Include quotes from credible climate scientists, policy makers, and affected communities. Present information objectively while clearly communicating the scientific consensus, fact-check all claims, and provide proper source attribution. Address common misconceptions with factual corrections. Focus on what's happening now, why it matters to readers, and what they can expect in the near future. Follow journalistic standards for accuracy, balance, and timeliness."
|
||||
|
||||
{% elif report_style == "social_media" %}
|
||||
# Social Media Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Create engaging social media content about AI that will stop the scroll and spark conversations! Start with an attention-grabbing hook like 'You won't believe what AI just did in hospitals this week 🤯' and structure as a compelling thread or post series. Include surprising facts, relatable examples (like AI helping doctors spot diseases or personalizing your Netflix recommendations), and interactive elements that encourage sharing and comments. Use strategic hashtags (#AI #Technology #Future), incorporate relevant emojis for visual appeal, and include questions that prompt audience engagement ('Have you noticed AI in your daily life? Drop examples below! 👇'). Make complex concepts digestible with bite-sized explanations, trending analogies, and shareable quotes. Include a clear call-to-action and optimize for the specific platform (Twitter threads, Instagram carousel, LinkedIn professional insights, or TikTok-style quick facts). Aim for high shareability with content that feels both informative and entertaining."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Develop viral-worthy social media content that makes climate change accessible and shareable without being preachy. Open with a scroll-stopping hook like 'The weather app on your phone is telling a bigger story than you think 📱🌡️' and break down complex science into digestible, engaging chunks. Use relatable comparisons (Earth's fever, atmosphere as a blanket), trending formats (before/after visuals, myth-busting series, quick facts), and interactive elements (polls, questions, challenges). Include strategic hashtags (#ClimateChange #Science #Environment), eye-catching emojis, and shareable graphics or infographics. Address common questions and misconceptions with clear, factual responses. Create content that encourages positive action rather than climate anxiety, ending with empowering steps followers can take. Optimize for platform-specific features (Instagram Stories, TikTok trends, Twitter threads) and include calls-to-action that drive engagement and sharing."
|
||||
|
||||
{% else %}
|
||||
# General Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Write a comprehensive 1000-word analysis of artificial intelligence's current applications in healthcare, education, and business. Include specific examples of AI tools being used in each sector, discuss both benefits and challenges, and provide insights into future trends. Structure the response with clear sections for each industry and conclude with key takeaways."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide a detailed explanation of climate change suitable for a general audience. Cover the scientific mechanisms behind global warming, major causes including greenhouse gas emissions, observable effects we're seeing today, and projected future impacts. Include specific data and examples, and explain the difference between weather and climate. Organize the response with clear headings and conclude with actionable steps individuals can take."
|
||||
{% endif %}
|
||||
+2
-159
@@ -2,21 +2,7 @@
|
||||
CURRENT_TIME: {{ CURRENT_TIME }}
|
||||
---
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
You are a distinguished academic researcher and scholarly writer. Your report must embody the highest standards of academic rigor and intellectual discourse. Write with the precision of a peer-reviewed journal article, employing sophisticated analytical frameworks, comprehensive literature synthesis, and methodological transparency. Your language should be formal, technical, and authoritative, utilizing discipline-specific terminology with exactitude. Structure arguments logically with clear thesis statements, supporting evidence, and nuanced conclusions. Maintain complete objectivity, acknowledge limitations, and present balanced perspectives on controversial topics. The report should demonstrate deep scholarly engagement and contribute meaningfully to academic knowledge.
|
||||
{% elif report_style == "popular_science" %}
|
||||
You are an award-winning science communicator and storyteller. Your mission is to transform complex scientific concepts into captivating narratives that spark curiosity and wonder in everyday readers. Write with the enthusiasm of a passionate educator, using vivid analogies, relatable examples, and compelling storytelling techniques. Your tone should be warm, approachable, and infectious in its excitement about discovery. Break down technical jargon into accessible language without sacrificing accuracy. Use metaphors, real-world comparisons, and human interest angles to make abstract concepts tangible. Think like a National Geographic writer or a TED Talk presenter - engaging, enlightening, and inspiring.
|
||||
{% elif report_style == "news" %}
|
||||
You are an NBC News correspondent and investigative journalist with decades of experience in breaking news and in-depth reporting. Your report must exemplify the gold standard of American broadcast journalism: authoritative, meticulously researched, and delivered with the gravitas and credibility that NBC News is known for. Write with the precision of a network news anchor, employing the classic inverted pyramid structure while weaving compelling human narratives. Your language should be clear, authoritative, and accessible to prime-time television audiences. Maintain NBC's tradition of balanced reporting, thorough fact-checking, and ethical journalism. Think like Lester Holt or Andrea Mitchell - delivering complex stories with clarity, context, and unwavering integrity.
|
||||
{% elif report_style == "social_media" %}
|
||||
{% if locale == "zh-CN" %}
|
||||
You are a popular 小红书 (Xiaohongshu) content creator specializing in lifestyle and knowledge sharing. Your report should embody the authentic, personal, and engaging style that resonates with 小红书 users. Write with genuine enthusiasm and a "姐妹们" (sisters) tone, as if sharing exciting discoveries with close friends. Use abundant emojis, create "种草" (grass-planting/recommendation) moments, and structure content for easy mobile consumption. Your writing should feel like a personal diary entry mixed with expert insights - warm, relatable, and irresistibly shareable. Think like a top 小红书 blogger who effortlessly combines personal experience with valuable information, making readers feel like they've discovered a hidden gem.
|
||||
{% else %}
|
||||
You are a viral Twitter content creator and digital influencer specializing in breaking down complex topics into engaging, shareable threads. Your report should be optimized for maximum engagement and viral potential across social media platforms. Write with energy, authenticity, and a conversational tone that resonates with global online communities. Use strategic hashtags, create quotable moments, and structure content for easy consumption and sharing. Think like a successful Twitter thought leader who can make any topic accessible, engaging, and discussion-worthy while maintaining credibility and accuracy.
|
||||
{% endif %}
|
||||
{% else %}
|
||||
You are a professional reporter responsible for writing clear, comprehensive reports based ONLY on provided information and verifiable facts. Your report should adopt a professional tone.
|
||||
{% endif %}
|
||||
You are a professional reporter responsible for writing clear, comprehensive reports based ONLY on provided information and verifiable facts.
|
||||
|
||||
# Role
|
||||
|
||||
@@ -57,40 +43,10 @@ Structure your report in the following format:
|
||||
- **Including images from the previous steps in the report is very helpful.**
|
||||
|
||||
5. **Survey Note** (for more comprehensive reports)
|
||||
{% if report_style == "academic" %}
|
||||
- **Literature Review & Theoretical Framework**: Comprehensive analysis of existing research and theoretical foundations
|
||||
- **Methodology & Data Analysis**: Detailed examination of research methods and analytical approaches
|
||||
- **Critical Discussion**: In-depth evaluation of findings with consideration of limitations and implications
|
||||
- **Future Research Directions**: Identification of gaps and recommendations for further investigation
|
||||
{% elif report_style == "popular_science" %}
|
||||
- **The Bigger Picture**: How this research fits into the broader scientific landscape
|
||||
- **Real-World Applications**: Practical implications and potential future developments
|
||||
- **Behind the Scenes**: Interesting details about the research process and challenges faced
|
||||
- **What's Next**: Exciting possibilities and upcoming developments in the field
|
||||
{% elif report_style == "news" %}
|
||||
- **NBC News Analysis**: In-depth examination of the story's broader implications and significance
|
||||
- **Impact Assessment**: How these developments affect different communities, industries, and stakeholders
|
||||
- **Expert Perspectives**: Insights from credible sources, analysts, and subject matter experts
|
||||
- **Timeline & Context**: Chronological background and historical context essential for understanding
|
||||
- **What's Next**: Expected developments, upcoming milestones, and stories to watch
|
||||
{% elif report_style == "social_media" %}
|
||||
{% if locale == "zh-CN" %}
|
||||
- **【种草时刻】**: 最值得关注的亮点和必须了解的核心信息
|
||||
- **【数据震撼】**: 用小红书风格展示重要统计数据和发现
|
||||
- **【姐妹们的看法】**: 社区热议话题和大家的真实反馈
|
||||
- **【行动指南】**: 实用建议和读者可以立即行动的清单
|
||||
{% else %}
|
||||
- **Thread Highlights**: Key takeaways formatted for maximum shareability
|
||||
- **Data That Matters**: Important statistics and findings presented for viral potential
|
||||
- **Community Pulse**: Trending discussions and reactions from the online community
|
||||
- **Action Steps**: Practical advice and immediate next steps for readers
|
||||
{% endif %}
|
||||
{% else %}
|
||||
- A more detailed, academic-style analysis.
|
||||
- Include comprehensive sections covering all aspects of the topic.
|
||||
- Can include comparative analysis, tables, and detailed feature breakdowns.
|
||||
- This section is optional for shorter reports.
|
||||
{% endif %}
|
||||
|
||||
6. **Key Citations**
|
||||
- List all references at the end in link reference format.
|
||||
@@ -100,64 +56,7 @@ Structure your report in the following format:
|
||||
# Writing Guidelines
|
||||
|
||||
1. Writing style:
|
||||
{% if report_style == "academic" %}
|
||||
**Academic Excellence Standards:**
|
||||
- Employ sophisticated, formal academic discourse with discipline-specific terminology
|
||||
- Construct complex, nuanced arguments with clear thesis statements and logical progression
|
||||
- Use third-person perspective and passive voice where appropriate for objectivity
|
||||
- Include methodological considerations and acknowledge research limitations
|
||||
- Reference theoretical frameworks and cite relevant scholarly work patterns
|
||||
- Maintain intellectual rigor with precise, unambiguous language
|
||||
- Avoid contractions, colloquialisms, and informal expressions entirely
|
||||
- Use hedging language appropriately ("suggests," "indicates," "appears to")
|
||||
{% elif report_style == "popular_science" %}
|
||||
**Science Communication Excellence:**
|
||||
- Write with infectious enthusiasm and genuine curiosity about discoveries
|
||||
- Transform technical jargon into vivid, relatable analogies and metaphors
|
||||
- Use active voice and engaging narrative techniques to tell scientific stories
|
||||
- Include "wow factor" moments and surprising revelations to maintain interest
|
||||
- Employ conversational tone while maintaining scientific accuracy
|
||||
- Use rhetorical questions to engage readers and guide their thinking
|
||||
- Include human elements: researcher personalities, discovery stories, real-world impacts
|
||||
- Balance accessibility with intellectual respect for your audience
|
||||
{% elif report_style == "news" %}
|
||||
**NBC News Editorial Standards:**
|
||||
- Open with a compelling lede that captures the essence of the story in 25-35 words
|
||||
- Use the classic inverted pyramid: most newsworthy information first, supporting details follow
|
||||
- Write in clear, conversational broadcast style that sounds natural when read aloud
|
||||
- Employ active voice and strong, precise verbs that convey action and urgency
|
||||
- Attribute every claim to specific, credible sources using NBC's attribution standards
|
||||
- Use present tense for ongoing situations, past tense for completed events
|
||||
- Maintain NBC's commitment to balanced reporting with multiple perspectives
|
||||
- Include essential context and background without overwhelming the main story
|
||||
- Verify information through at least two independent sources when possible
|
||||
- Clearly label speculation, analysis, and ongoing investigations
|
||||
- Use transitional phrases that guide readers smoothly through the narrative
|
||||
{% elif report_style == "social_media" %}
|
||||
{% if locale == "zh-CN" %}
|
||||
**小红书风格写作标准:**
|
||||
- 用"姐妹们!"、"宝子们!"等亲切称呼开头,营造闺蜜聊天氛围
|
||||
- 大量使用emoji表情符号增强表达力和视觉吸引力 ✨��
|
||||
- 采用"种草"语言:"真的绝了!"、"必须安利给大家!"、"不看后悔系列!"
|
||||
- 使用小红书特色标题格式:"【干货分享】"、"【亲测有效】"、"【避雷指南】"
|
||||
- 穿插个人感受和体验:"我当时看到这个数据真的震惊了!"
|
||||
- 用数字和符号增强视觉效果:①②③、✅❌、🔥💡⭐
|
||||
- 创造"金句"和可截图分享的内容段落
|
||||
- 结尾用互动性语言:"你们觉得呢?"、"评论区聊聊!"、"记得点赞收藏哦!"
|
||||
{% else %}
|
||||
**Twitter/X Engagement Standards:**
|
||||
- Open with attention-grabbing hooks that stop the scroll
|
||||
- Use thread-style formatting with numbered points (1/n, 2/n, etc.)
|
||||
- Incorporate strategic hashtags for discoverability and trending topics
|
||||
- Write quotable, tweetable snippets that beg to be shared
|
||||
- Use conversational, authentic voice with personality and wit
|
||||
- Include relevant emojis to enhance meaning and visual appeal 🧵📊💡
|
||||
- Create "thread-worthy" content with clear progression and payoff
|
||||
- End with engagement prompts: "What do you think?", "Retweet if you agree"
|
||||
{% endif %}
|
||||
{% else %}
|
||||
- Use a professional tone.
|
||||
{% endif %}
|
||||
- Use professional tone.
|
||||
- Be concise and precise.
|
||||
- Avoid speculation.
|
||||
- Support claims with evidence.
|
||||
@@ -178,62 +77,6 @@ Structure your report in the following format:
|
||||
- Use horizontal rules (---) to separate major sections.
|
||||
- Track the sources of information but keep the main text clean and readable.
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
**Academic Formatting Specifications:**
|
||||
- Use formal section headings with clear hierarchical structure (## Introduction, ### Methodology, #### Subsection)
|
||||
- Employ numbered lists for methodological steps and logical sequences
|
||||
- Use block quotes for important definitions or key theoretical concepts
|
||||
- Include detailed tables with comprehensive headers and statistical data
|
||||
- Use footnote-style formatting for additional context or clarifications
|
||||
- Maintain consistent academic citation patterns throughout
|
||||
- Use `code blocks` for technical specifications, formulas, or data samples
|
||||
{% elif report_style == "popular_science" %}
|
||||
**Science Communication Formatting:**
|
||||
- Use engaging, descriptive headings that spark curiosity ("The Surprising Discovery That Changed Everything")
|
||||
- Employ creative formatting like callout boxes for "Did You Know?" facts
|
||||
- Use bullet points for easy-to-digest key findings
|
||||
- Include visual breaks with strategic use of bold text for emphasis
|
||||
- Format analogies and metaphors prominently to aid understanding
|
||||
- Use numbered lists for step-by-step explanations of complex processes
|
||||
- Highlight surprising statistics or findings with special formatting
|
||||
{% elif report_style == "news" %}
|
||||
**NBC News Formatting Standards:**
|
||||
- Craft headlines that are informative yet compelling, following NBC's style guide
|
||||
- Use NBC-style datelines and bylines for professional credibility
|
||||
- Structure paragraphs for broadcast readability (1-2 sentences for digital, 2-3 for print)
|
||||
- Employ strategic subheadings that advance the story narrative
|
||||
- Format direct quotes with proper attribution and context
|
||||
- Use bullet points sparingly, primarily for breaking news updates or key facts
|
||||
- Include "BREAKING" or "DEVELOPING" labels for ongoing stories
|
||||
- Format source attribution clearly: "according to NBC News," "sources tell NBC News"
|
||||
- Use italics for emphasis on key terms or breaking developments
|
||||
- Structure the story with clear sections: Lede, Context, Analysis, Looking Ahead
|
||||
{% elif report_style == "social_media" %}
|
||||
{% if locale == "zh-CN" %}
|
||||
**小红书格式优化标准:**
|
||||
- 使用吸睛标题配合emoji:"🔥【重磅】这个发现太震撼了!"
|
||||
- 关键数据用醒目格式突出:「 重点数据 」或 ⭐ 核心发现 ⭐
|
||||
- 适度使用大写强调:真的YYDS!、绝绝子!
|
||||
- 用emoji作为分点符号:✨、🌟、�、�、💯
|
||||
- 创建话题标签区域:#科技前沿 #必看干货 #涨知识了
|
||||
- 设置"划重点"总结区域,方便快速阅读
|
||||
- 利用换行和空白营造手机阅读友好的版式
|
||||
- 制作"金句卡片"格式,便于截图分享
|
||||
- 使用分割线和特殊符号:「」『』【】━━━━━━
|
||||
{% else %}
|
||||
**Twitter/X Formatting Standards:**
|
||||
- Use compelling headlines with strategic emoji placement 🧵⚡️🔥
|
||||
- Format key insights as standalone, quotable tweet blocks
|
||||
- Employ thread numbering for multi-part content (1/12, 2/12, etc.)
|
||||
- Use bullet points with emoji bullets for visual appeal
|
||||
- Include strategic hashtags at the end: #TechNews #Innovation #MustRead
|
||||
- Create "TL;DR" summaries for quick consumption
|
||||
- Use line breaks and white space for mobile readability
|
||||
- Format "quotable moments" with clear visual separation
|
||||
- Include call-to-action elements: "🔄 RT to share" "💬 What's your take?"
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
# Data Integrity
|
||||
|
||||
- Only use information explicitly provided in the input.
|
||||
|
||||
@@ -122,3 +122,12 @@ def parse_uri(uri: str) -> tuple[str, str]:
|
||||
if parsed.scheme != "rag":
|
||||
raise ValueError(f"Invalid URI: {uri}")
|
||||
return parsed.path.split("/")[1], parsed.fragment
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uri = "rag://dataset/123#abc"
|
||||
parsed = urlparse(uri)
|
||||
print(parsed.scheme)
|
||||
print(parsed.netloc)
|
||||
print(parsed.path)
|
||||
print(parsed.fragment)
|
||||
|
||||
+2
-53
@@ -14,19 +14,16 @@ from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage, BaseMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.graph.builder import build_graph_with_memory
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||
from src.prose.graph.builder import build_graph as build_prose_graph
|
||||
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
||||
from src.rag.builder import build_retriever
|
||||
from src.rag.retriever import Resource
|
||||
from src.server.chat_request import (
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
GeneratePPTRequest,
|
||||
GenerateProseRequest,
|
||||
@@ -80,14 +77,13 @@ async def chat_stream(request: ChatRequest):
|
||||
request.interrupt_feedback,
|
||||
request.mcp_settings,
|
||||
request.enable_background_investigation,
|
||||
request.report_style,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
async def _astream_workflow_generator(
|
||||
messages: List[dict],
|
||||
messages: List[ChatMessage],
|
||||
thread_id: str,
|
||||
resources: List[Resource],
|
||||
max_plan_iterations: int,
|
||||
@@ -96,8 +92,7 @@ async def _astream_workflow_generator(
|
||||
auto_accepted_plan: bool,
|
||||
interrupt_feedback: str,
|
||||
mcp_settings: dict,
|
||||
enable_background_investigation: bool,
|
||||
report_style: ReportStyle,
|
||||
enable_background_investigation,
|
||||
):
|
||||
input_ = {
|
||||
"messages": messages,
|
||||
@@ -107,7 +102,6 @@ async def _astream_workflow_generator(
|
||||
"observations": [],
|
||||
"auto_accepted_plan": auto_accepted_plan,
|
||||
"enable_background_investigation": enable_background_investigation,
|
||||
"research_topic": messages[-1]["content"] if messages else "",
|
||||
}
|
||||
if not auto_accepted_plan and interrupt_feedback:
|
||||
resume_msg = f"[{interrupt_feedback}]"
|
||||
@@ -124,7 +118,6 @@ async def _astream_workflow_generator(
|
||||
"max_step_num": max_step_num,
|
||||
"max_search_results": max_search_results,
|
||||
"mcp_settings": mcp_settings,
|
||||
"report_style": report_style.value,
|
||||
},
|
||||
stream_mode=["messages", "updates"],
|
||||
subgraphs=True,
|
||||
@@ -303,50 +296,6 @@ async def generate_prose(request: GenerateProseRequest):
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/prompt/enhance")
|
||||
async def enhance_prompt(request: EnhancePromptRequest):
|
||||
try:
|
||||
sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
|
||||
logger.info(f"Enhancing prompt: {sanitized_prompt}")
|
||||
|
||||
# Convert string report_style to ReportStyle enum
|
||||
report_style = None
|
||||
if request.report_style:
|
||||
try:
|
||||
# Handle both uppercase and lowercase input
|
||||
style_mapping = {
|
||||
"ACADEMIC": ReportStyle.ACADEMIC,
|
||||
"POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
|
||||
"NEWS": ReportStyle.NEWS,
|
||||
"SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
|
||||
"academic": ReportStyle.ACADEMIC,
|
||||
"popular_science": ReportStyle.POPULAR_SCIENCE,
|
||||
"news": ReportStyle.NEWS,
|
||||
"social_media": ReportStyle.SOCIAL_MEDIA,
|
||||
}
|
||||
report_style = style_mapping.get(
|
||||
request.report_style, ReportStyle.ACADEMIC
|
||||
)
|
||||
except Exception:
|
||||
# If invalid style, default to ACADEMIC
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
else:
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
|
||||
workflow = build_prompt_enhancer_graph()
|
||||
final_state = workflow.invoke(
|
||||
{
|
||||
"prompt": request.prompt,
|
||||
"context": request.context,
|
||||
"report_style": report_style,
|
||||
}
|
||||
)
|
||||
return {"result": final_state["output"]}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
|
||||
async def mcp_server_metadata(request: MCPServerMetadataRequest):
|
||||
"""Get information about an MCP server."""
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.rag.retriever import Resource
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class ContentItem(BaseModel):
|
||||
@@ -59,9 +58,6 @@ class ChatRequest(BaseModel):
|
||||
enable_background_investigation: Optional[bool] = Field(
|
||||
True, description="Whether to get background investigation before plan"
|
||||
)
|
||||
report_style: Optional[ReportStyle] = Field(
|
||||
ReportStyle.ACADEMIC, description="The style of the report"
|
||||
)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
@@ -94,13 +90,3 @@ class GenerateProseRequest(BaseModel):
|
||||
command: Optional[str] = Field(
|
||||
"", description="The user custom command of the prose writer"
|
||||
)
|
||||
|
||||
|
||||
class EnhancePromptRequest(BaseModel):
|
||||
prompt: str = Field(..., description="The original prompt to enhance")
|
||||
context: Optional[str] = Field(
|
||||
"", description="Additional context about the intended use"
|
||||
)
|
||||
report_style: Optional[str] = Field(
|
||||
"academic", description="The style of the report"
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
||||
"include_images": include_images,
|
||||
"include_image_descriptions": include_image_descriptions,
|
||||
}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
|
||||
if res.status == 200:
|
||||
data = await res.text()
|
||||
|
||||
@@ -20,7 +20,6 @@ MOCK_SEARCH_RESULTS = [
|
||||
def mock_state():
|
||||
return {
|
||||
"messages": [HumanMessage(content="test query")],
|
||||
"research_topic": "test query",
|
||||
"background_investigation_results": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -106,40 +106,3 @@ def test_current_time_format():
|
||||
assert any(
|
||||
line.strip().startswith("CURRENT_TIME:") for line in system_content.split("\n")
|
||||
)
|
||||
|
||||
|
||||
def test_apply_prompt_template_reporter():
|
||||
"""Test reporter template rendering with different styles and locale"""
|
||||
|
||||
test_state_news = {
|
||||
"messages": [],
|
||||
"task": "test reporter task",
|
||||
"workspace_context": "test reporter context",
|
||||
"report_style": "news",
|
||||
"locale": "en-US",
|
||||
}
|
||||
messages_news = apply_prompt_template("reporter", test_state_news)
|
||||
system_content_news = messages_news[0]["content"]
|
||||
assert "NBC News" in system_content_news
|
||||
|
||||
test_state_social_media_en = {
|
||||
"messages": [],
|
||||
"task": "test reporter task",
|
||||
"workspace_context": "test reporter context",
|
||||
"report_style": "social_media",
|
||||
"locale": "en-US",
|
||||
}
|
||||
messages_default = apply_prompt_template("reporter", test_state_social_media_en)
|
||||
system_content_default = messages_default[0]["content"]
|
||||
assert "Twitter/X" in system_content_default
|
||||
|
||||
test_state_social_media_cn = {
|
||||
"messages": [],
|
||||
"task": "test reporter task",
|
||||
"workspace_context": "test reporter context",
|
||||
"report_style": "social_media",
|
||||
"locale": "zh-CN",
|
||||
}
|
||||
messages_cn = apply_prompt_template("reporter", test_state_social_media_cn)
|
||||
system_content_cn = messages_cn[0]["content"]
|
||||
assert "小红书" in system_content_cn
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
import builtins
|
||||
import importlib
|
||||
from src.config.configuration import Configuration
|
||||
|
||||
# Patch sys.path so relative import works
|
||||
|
||||
# Patch Resource for import
|
||||
mock_resource = type("Resource", (), {})
|
||||
|
||||
# Patch src.rag.retriever.Resource for import
|
||||
|
||||
module_name = "src.rag.retriever"
|
||||
if module_name not in sys.modules:
|
||||
retriever_mod = types.ModuleType(module_name)
|
||||
retriever_mod.Resource = mock_resource
|
||||
sys.modules[module_name] = retriever_mod
|
||||
|
||||
# Relative import of Configuration
|
||||
|
||||
|
||||
def test_default_configuration():
|
||||
config = Configuration()
|
||||
assert config.resources == []
|
||||
assert config.max_plan_iterations == 1
|
||||
assert config.max_step_num == 3
|
||||
assert config.max_search_results == 3
|
||||
assert config.mcp_settings is None
|
||||
|
||||
|
||||
def test_from_runnable_config_with_config_dict(monkeypatch):
|
||||
config_dict = {
|
||||
"configurable": {
|
||||
"max_plan_iterations": 5,
|
||||
"max_step_num": 7,
|
||||
"max_search_results": 10,
|
||||
"mcp_settings": {"foo": "bar"},
|
||||
}
|
||||
}
|
||||
config = Configuration.from_runnable_config(config_dict)
|
||||
assert config.max_plan_iterations == 5
|
||||
assert config.max_step_num == 7
|
||||
assert config.max_search_results == 10
|
||||
assert config.mcp_settings == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_from_runnable_config_with_env_override(monkeypatch):
|
||||
monkeypatch.setenv("MAX_PLAN_ITERATIONS", "9")
|
||||
monkeypatch.setenv("MAX_STEP_NUM", "11")
|
||||
config_dict = {
|
||||
"configurable": {
|
||||
"max_plan_iterations": 2,
|
||||
"max_step_num": 3,
|
||||
"max_search_results": 4,
|
||||
}
|
||||
}
|
||||
config = Configuration.from_runnable_config(config_dict)
|
||||
# Environment variables take precedence and are strings
|
||||
assert config.max_plan_iterations == "9"
|
||||
assert config.max_step_num == "11"
|
||||
assert config.max_search_results == 4 # not overridden
|
||||
# Clean up
|
||||
monkeypatch.delenv("MAX_PLAN_ITERATIONS")
|
||||
monkeypatch.delenv("MAX_STEP_NUM")
|
||||
|
||||
|
||||
def test_from_runnable_config_with_none_and_falsy(monkeypatch):
|
||||
config_dict = {
|
||||
"configurable": {
|
||||
"max_plan_iterations": None,
|
||||
"max_step_num": 0, # falsy, should be skipped
|
||||
"max_search_results": "",
|
||||
}
|
||||
}
|
||||
config = Configuration.from_runnable_config(config_dict)
|
||||
# Should fall back to defaults for skipped/falsy values
|
||||
assert config.max_plan_iterations == 1
|
||||
assert config.max_step_num == 3
|
||||
assert config.max_search_results == 3
|
||||
|
||||
|
||||
def test_from_runnable_config_with_no_config():
|
||||
config = Configuration.from_runnable_config()
|
||||
assert config.max_plan_iterations == 1
|
||||
assert config.max_step_num == 3
|
||||
assert config.max_search_results == 3
|
||||
assert config.resources == []
|
||||
assert config.mcp_settings is None
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import yaml
|
||||
import pytest
|
||||
from src.config.loader import load_yaml_config, process_dict, replace_env_vars
|
||||
|
||||
|
||||
def test_replace_env_vars_with_env(monkeypatch):
|
||||
monkeypatch.setenv("TEST_ENV", "env_value")
|
||||
assert replace_env_vars("$TEST_ENV") == "env_value"
|
||||
|
||||
|
||||
def test_replace_env_vars_without_env(monkeypatch):
|
||||
monkeypatch.delenv("NOT_SET_ENV", raising=False)
|
||||
assert replace_env_vars("$NOT_SET_ENV") == "NOT_SET_ENV"
|
||||
|
||||
|
||||
def test_replace_env_vars_non_string():
|
||||
assert replace_env_vars(123) == 123
|
||||
|
||||
|
||||
def test_replace_env_vars_regular_string():
|
||||
assert replace_env_vars("no_env") == "no_env"
|
||||
|
||||
|
||||
def test_process_dict_nested(monkeypatch):
|
||||
monkeypatch.setenv("FOO", "bar")
|
||||
config = {"a": "$FOO", "b": {"c": "$FOO", "d": 42, "e": "$NOT_SET_ENV"}}
|
||||
processed = process_dict(config)
|
||||
assert processed["a"] == "bar"
|
||||
assert processed["b"]["c"] == "bar"
|
||||
assert processed["b"]["d"] == 42
|
||||
assert processed["b"]["e"] == "NOT_SET_ENV"
|
||||
|
||||
|
||||
def test_process_dict_empty():
|
||||
assert process_dict({}) == {}
|
||||
|
||||
|
||||
def test_load_yaml_config_file_not_exist():
|
||||
assert load_yaml_config("non_existent_file.yaml") == {}
|
||||
|
||||
|
||||
def test_load_yaml_config(monkeypatch):
|
||||
monkeypatch.setenv("MY_ENV", "my_value")
|
||||
yaml_content = """
|
||||
key1: value1
|
||||
key2: $MY_ENV
|
||||
nested:
|
||||
key3: $MY_ENV
|
||||
key4: 123
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
|
||||
tmp.write(yaml_content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
config = load_yaml_config(tmp_path)
|
||||
assert config["key1"] == "value1"
|
||||
assert config["key2"] == "my_value"
|
||||
assert config["nested"]["key3"] == "my_value"
|
||||
assert config["nested"]["key4"] == 123
|
||||
finally:
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_load_yaml_config_cache(monkeypatch):
|
||||
monkeypatch.setenv("CACHE_ENV", "cache_value")
|
||||
yaml_content = "foo: $CACHE_ENV"
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
|
||||
tmp.write(yaml_content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
config1 = load_yaml_config(tmp_path)
|
||||
config2 = load_yaml_config(tmp_path)
|
||||
assert config1 is config2 # Should be cached (same object)
|
||||
assert config1["foo"] == "cache_value"
|
||||
finally:
|
||||
os.remove(tmp_path)
|
||||
@@ -1,74 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
import pytest
|
||||
from src.crawler.article import Article
|
||||
|
||||
|
||||
class DummyMarkdownify:
|
||||
"""A dummy markdownify replacement for patching if needed."""
|
||||
|
||||
@staticmethod
|
||||
def markdownify(html):
|
||||
return html
|
||||
|
||||
|
||||
def test_to_markdown_includes_title(monkeypatch):
|
||||
article = Article("Test Title", "<p>Hello <b>world</b>!</p>")
|
||||
result = article.to_markdown(including_title=True)
|
||||
assert result.startswith("# Test Title")
|
||||
assert "Hello" in result
|
||||
|
||||
|
||||
def test_to_markdown_excludes_title():
|
||||
article = Article("Test Title", "<p>Hello <b>world</b>!</p>")
|
||||
result = article.to_markdown(including_title=False)
|
||||
assert not result.startswith("# Test Title")
|
||||
assert "Hello" in result
|
||||
|
||||
|
||||
def test_to_message_with_text_only():
|
||||
article = Article("Test Title", "<p>Hello world!</p>")
|
||||
article.url = "https://example.com/"
|
||||
result = article.to_message()
|
||||
assert isinstance(result, list)
|
||||
assert any(item["type"] == "text" for item in result)
|
||||
assert all("type" in item for item in result)
|
||||
|
||||
|
||||
def test_to_message_with_image(monkeypatch):
|
||||
html = '<p>Intro</p><img src="img/pic.png"/>'
|
||||
article = Article("Title", html)
|
||||
article.url = "https://host.com/path/"
|
||||
# The markdownify library will convert <img> to markdown image syntax
|
||||
result = article.to_message()
|
||||
# Should have both text and image_url types
|
||||
types = [item["type"] for item in result]
|
||||
assert "image_url" in types
|
||||
assert "text" in types
|
||||
# Check that the image_url is correctly joined
|
||||
image_items = [item for item in result if item["type"] == "image_url"]
|
||||
assert image_items
|
||||
assert image_items[0]["image_url"]["url"] == "https://host.com/path/img/pic.png"
|
||||
|
||||
|
||||
def test_to_message_multiple_images():
|
||||
html = '<p>Start</p><img src="a.png"/><p>Mid</p><img src="b.jpg"/>End'
|
||||
article = Article("Title", html)
|
||||
article.url = "http://x/"
|
||||
result = article.to_message()
|
||||
image_urls = [
|
||||
item["image_url"]["url"] for item in result if item["type"] == "image_url"
|
||||
]
|
||||
assert "http://x/a.png" in image_urls
|
||||
assert "http://x/b.jpg" in image_urls
|
||||
text_items = [item for item in result if item["type"] == "text"]
|
||||
assert any("Start" in item["text"] for item in text_items)
|
||||
assert any("Mid" in item["text"] for item in text_items)
|
||||
|
||||
|
||||
def test_to_message_handles_empty_html():
|
||||
article = Article("Empty", "")
|
||||
article.url = "http://test/"
|
||||
result = article.to_message()
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["type"] == "text"
|
||||
@@ -1,72 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
import src.crawler as crawler_module
|
||||
from src.crawler import Crawler
|
||||
|
||||
|
||||
def test_crawler_sets_article_url(monkeypatch):
|
||||
"""Test that the crawler sets the article.url field correctly."""
|
||||
|
||||
class DummyArticle:
|
||||
def __init__(self):
|
||||
self.url = None
|
||||
|
||||
def to_markdown(self):
|
||||
return "# Dummy"
|
||||
|
||||
class DummyJinaClient:
|
||||
def crawl(self, url, return_format=None):
|
||||
return "<html>dummy</html>"
|
||||
|
||||
class DummyReadabilityExtractor:
|
||||
def extract_article(self, html):
|
||||
return DummyArticle()
|
||||
|
||||
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
|
||||
monkeypatch.setattr(
|
||||
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
|
||||
)
|
||||
|
||||
crawler = crawler_module.Crawler()
|
||||
url = "http://example.com"
|
||||
article = crawler.crawl(url)
|
||||
assert article.url == url
|
||||
assert article.to_markdown() == "# Dummy"
|
||||
|
||||
|
||||
def test_crawler_calls_dependencies(monkeypatch):
|
||||
"""Test that Crawler calls JinaClient.crawl and ReadabilityExtractor.extract_article."""
|
||||
calls = {}
|
||||
|
||||
class DummyJinaClient:
|
||||
def crawl(self, url, return_format=None):
|
||||
calls["jina"] = (url, return_format)
|
||||
return "<html>dummy</html>"
|
||||
|
||||
class DummyReadabilityExtractor:
|
||||
def extract_article(self, html):
|
||||
calls["extractor"] = html
|
||||
|
||||
class DummyArticle:
|
||||
url = None
|
||||
|
||||
def to_markdown(self):
|
||||
return "# Dummy"
|
||||
|
||||
return DummyArticle()
|
||||
|
||||
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
|
||||
monkeypatch.setattr(
|
||||
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
|
||||
)
|
||||
|
||||
crawler = crawler_module.Crawler()
|
||||
url = "http://example.com"
|
||||
crawler.crawl(url)
|
||||
assert "jina" in calls
|
||||
assert calls["jina"][0] == url
|
||||
assert calls["jina"][1] == "html"
|
||||
assert "extractor" in calls
|
||||
assert calls["extractor"] == "<html>dummy</html>"
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import types
|
||||
import pytest
|
||||
from src.llms import llm
|
||||
|
||||
|
||||
class DummyChatOpenAI:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def invoke(self, msg):
|
||||
return f"Echo: {msg}"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_chat_openai(monkeypatch):
|
||||
monkeypatch.setattr(llm, "ChatOpenAI", DummyChatOpenAI)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_conf():
|
||||
return {
|
||||
"BASIC_MODEL": {"api_key": "test_key", "base_url": "http://test"},
|
||||
"REASONING_MODEL": {"api_key": "reason_key"},
|
||||
"VISION_MODEL": {"api_key": "vision_key"},
|
||||
}
|
||||
|
||||
|
||||
def test_get_env_llm_conf(monkeypatch):
|
||||
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
|
||||
monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")
|
||||
conf = llm._get_env_llm_conf("basic")
|
||||
assert conf["api_key"] == "env_key"
|
||||
assert conf["base_url"] == "http://env"
|
||||
|
||||
|
||||
def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
|
||||
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
|
||||
result = llm._create_llm_use_conf("basic", dummy_conf)
|
||||
assert isinstance(result, DummyChatOpenAI)
|
||||
assert result.kwargs["api_key"] == "env_key"
|
||||
assert result.kwargs["base_url"] == "http://test"
|
||||
|
||||
|
||||
def test_create_llm_use_conf_invalid_type(dummy_conf):
|
||||
with pytest.raises(ValueError):
|
||||
llm._create_llm_use_conf("unknown", dummy_conf)
|
||||
|
||||
|
||||
def test_create_llm_use_conf_empty_conf(monkeypatch):
|
||||
with pytest.raises(ValueError):
|
||||
llm._create_llm_use_conf("basic", {})
|
||||
|
||||
|
||||
def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
|
||||
called = {}
|
||||
|
||||
def fake_load_yaml_config(path):
|
||||
called["called"] = True
|
||||
return dummy_conf
|
||||
|
||||
monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
|
||||
llm._llm_cache.clear()
|
||||
inst1 = llm.get_llm_by_type("basic")
|
||||
inst2 = llm.get_llm_by_type("basic")
|
||||
assert inst1 is inst2
|
||||
assert called["called"]
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,156 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.prompt_enhancer.graph.builder import build_graph
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class TestBuildGraph:
|
||||
"""Test cases for build_graph function."""
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_structure(self, mock_state_graph):
|
||||
"""Test that build_graph creates the correct graph structure."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
# Verify StateGraph was created with correct state type
|
||||
mock_state_graph.assert_called_once_with(PromptEnhancerState)
|
||||
|
||||
# Verify entry point was set
|
||||
mock_builder.set_entry_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify finish point was set
|
||||
mock_builder.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify graph was compiled
|
||||
mock_builder.compile.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
@patch("src.prompt_enhancer.graph.builder.prompt_enhancer_node")
|
||||
def test_build_graph_node_function(self, mock_enhancer_node, mock_state_graph):
|
||||
"""Test that the correct node function is added to the graph."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
# Verify the correct node function was added
|
||||
mock_builder.add_node.assert_called_once_with("enhancer", mock_enhancer_node)
|
||||
|
||||
def test_build_graph_returns_compiled_graph(self):
|
||||
"""Test that build_graph returns a compiled graph object."""
|
||||
with patch("src.prompt_enhancer.graph.builder.StateGraph") as mock_state_graph:
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
assert result is mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_call_sequence(self, mock_state_graph):
|
||||
"""Test that build_graph calls methods in the correct sequence."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
# Track call order
|
||||
call_order = []
|
||||
|
||||
def track_add_node(*args, **kwargs):
|
||||
call_order.append("add_node")
|
||||
|
||||
def track_set_entry_point(*args, **kwargs):
|
||||
call_order.append("set_entry_point")
|
||||
|
||||
def track_set_finish_point(*args, **kwargs):
|
||||
call_order.append("set_finish_point")
|
||||
|
||||
def track_compile(*args, **kwargs):
|
||||
call_order.append("compile")
|
||||
return mock_compiled_graph
|
||||
|
||||
mock_builder.add_node.side_effect = track_add_node
|
||||
mock_builder.set_entry_point.side_effect = track_set_entry_point
|
||||
mock_builder.set_finish_point.side_effect = track_set_finish_point
|
||||
mock_builder.compile.side_effect = track_compile
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify the correct call sequence
|
||||
expected_order = ["add_node", "set_entry_point", "set_finish_point", "compile"]
|
||||
assert call_order == expected_order
|
||||
|
||||
def test_build_graph_integration(self):
|
||||
"""Integration test to verify the graph can be built without mocking."""
|
||||
# This test verifies that all imports and dependencies are correct
|
||||
try:
|
||||
graph = build_graph()
|
||||
assert graph is not None
|
||||
# The graph should be a compiled LangGraph object
|
||||
assert hasattr(graph, "invoke") or hasattr(graph, "stream")
|
||||
except ImportError as e:
|
||||
pytest.skip(f"Skipping integration test due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
# If there are configuration issues (like missing LLM config),
|
||||
# we still consider the test successful if the graph structure is built
|
||||
if "LLM" in str(e) or "configuration" in str(e).lower():
|
||||
pytest.skip(
|
||||
f"Skipping integration test due to configuration issues: {e}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_single_node_workflow(self, mock_state_graph):
|
||||
"""Test that the graph is configured as a single-node workflow."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify only one node is added
|
||||
assert mock_builder.add_node.call_count == 1
|
||||
|
||||
# Verify entry and finish points are the same node
|
||||
mock_builder.set_entry_point.assert_called_once_with("enhancer")
|
||||
mock_builder.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_state_type(self, mock_state_graph):
|
||||
"""Test that the graph is initialized with the correct state type."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify StateGraph was initialized with PromptEnhancerState
|
||||
args, kwargs = mock_state_graph.call_args
|
||||
assert args[0] == PromptEnhancerState
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
"""Mock LLM that returns a test response."""
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = MagicMock(content="Enhanced test prompt")
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messages():
|
||||
"""Mock messages returned by apply_prompt_template."""
|
||||
return [
|
||||
SystemMessage(content="System prompt template"),
|
||||
HumanMessage(content="Test human message"),
|
||||
]
|
||||
|
||||
|
||||
class TestPromptEnhancerNode:
|
||||
"""Test cases for prompt_enhancer_node function."""
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_basic_prompt_enhancement(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test basic prompt enhancement without context or report style."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(prompt="Write about AI")
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify LLM was called
|
||||
mock_get_llm.assert_called_once_with("basic")
|
||||
mock_llm.invoke.assert_called_once_with(mock_messages)
|
||||
|
||||
# Verify apply_prompt_template was called correctly
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
|
||||
assert "messages" in call_args[0][1]
|
||||
assert "report_style" in call_args[0][1]
|
||||
|
||||
# Verify result
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prompt_enhancement_with_report_style(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test prompt enhancement with report style."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI", report_style=ReportStyle.ACADEMIC
|
||||
)
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify apply_prompt_template was called with report_style
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
|
||||
assert call_args[0][1]["report_style"] == ReportStyle.ACADEMIC
|
||||
|
||||
# Verify result
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prompt_enhancement_with_context(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test prompt enhancement with additional context."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI", context="Focus on machine learning applications"
|
||||
)
|
||||
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Verify apply_prompt_template was called
|
||||
mock_apply_template.assert_called_once()
|
||||
call_args = mock_apply_template.call_args
|
||||
|
||||
# Check that the context was included in the human message
|
||||
messages_arg = call_args[0][1]["messages"]
|
||||
assert len(messages_arg) == 1
|
||||
human_message = messages_arg[0]
|
||||
assert isinstance(human_message, HumanMessage)
|
||||
assert "Focus on machine learning applications" in human_message.content
|
||||
|
||||
assert result == {"output": "Enhanced test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_error_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test error handling when LLM call fails."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Mock LLM to raise an exception
|
||||
mock_llm.invoke.side_effect = Exception("LLM error")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should return original prompt on error
|
||||
assert result == {"output": "Test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_template_error_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test error handling when template application fails."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
# Mock apply_prompt_template to raise an exception
|
||||
mock_apply_template.side_effect = Exception("Template error")
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
# Should return original prompt on error
|
||||
assert result == {"output": "Test prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_prefix_removal(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test that common prefixes are removed from LLM response."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Test different prefixes that should be removed
|
||||
test_cases = [
|
||||
"Enhanced Prompt: This is the enhanced prompt",
|
||||
"Enhanced prompt: This is the enhanced prompt",
|
||||
"Here's the enhanced prompt: This is the enhanced prompt",
|
||||
"Here is the enhanced prompt: This is the enhanced prompt",
|
||||
"**Enhanced Prompt**: This is the enhanced prompt",
|
||||
"**Enhanced prompt**: This is the enhanced prompt",
|
||||
]
|
||||
|
||||
for response_with_prefix in test_cases:
|
||||
mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "This is the enhanced prompt"}
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
|
||||
@patch(
|
||||
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
|
||||
{"prompt_enhancer": "basic"},
|
||||
)
|
||||
def test_whitespace_handling(
|
||||
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
|
||||
):
|
||||
"""Test that whitespace is properly stripped from LLM response."""
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_apply_template.return_value = mock_messages
|
||||
|
||||
# Mock LLM response with extra whitespace
|
||||
mock_llm.invoke.return_value = MagicMock(
|
||||
content=" \n\n Enhanced prompt \n\n "
|
||||
)
|
||||
|
||||
state = PromptEnhancerState(prompt="Test prompt")
|
||||
result = prompt_enhancer_node(state)
|
||||
|
||||
assert result == {"output": "Enhanced prompt"}
|
||||
@@ -1,108 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_creation():
|
||||
"""Test that PromptEnhancerState can be created with required fields."""
|
||||
state = PromptEnhancerState(
|
||||
prompt="Test prompt", context=None, report_style=None, output=None
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Test prompt"
|
||||
assert state["context"] is None
|
||||
assert state["report_style"] is None
|
||||
assert state["output"] is None
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_all_fields():
|
||||
"""Test PromptEnhancerState with all fields populated."""
|
||||
state = PromptEnhancerState(
|
||||
prompt="Write about AI",
|
||||
context="Additional context about AI research",
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
output="Enhanced prompt about AI research",
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Write about AI"
|
||||
assert state["context"] == "Additional context about AI research"
|
||||
assert state["report_style"] == ReportStyle.ACADEMIC
|
||||
assert state["output"] == "Enhanced prompt about AI research"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_minimal():
|
||||
"""Test PromptEnhancerState with only required prompt field."""
|
||||
state = PromptEnhancerState(prompt="Minimal prompt")
|
||||
|
||||
assert state["prompt"] == "Minimal prompt"
|
||||
# Optional fields should not be present if not specified
|
||||
assert "context" not in state
|
||||
assert "report_style" not in state
|
||||
assert "output" not in state
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_different_report_styles():
|
||||
"""Test PromptEnhancerState with different ReportStyle values."""
|
||||
styles = [
|
||||
ReportStyle.ACADEMIC,
|
||||
ReportStyle.POPULAR_SCIENCE,
|
||||
ReportStyle.NEWS,
|
||||
ReportStyle.SOCIAL_MEDIA,
|
||||
]
|
||||
|
||||
for style in styles:
|
||||
state = PromptEnhancerState(prompt="Test prompt", report_style=style)
|
||||
assert state["report_style"] == style
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_update():
|
||||
"""Test updating PromptEnhancerState fields."""
|
||||
state = PromptEnhancerState(prompt="Original prompt")
|
||||
|
||||
# Update with new fields
|
||||
state.update(
|
||||
{
|
||||
"context": "New context",
|
||||
"report_style": ReportStyle.NEWS,
|
||||
"output": "Enhanced output",
|
||||
}
|
||||
)
|
||||
|
||||
assert state["prompt"] == "Original prompt"
|
||||
assert state["context"] == "New context"
|
||||
assert state["report_style"] == ReportStyle.NEWS
|
||||
assert state["output"] == "Enhanced output"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_get_method():
|
||||
"""Test using get() method on PromptEnhancerState."""
|
||||
state = PromptEnhancerState(prompt="Test prompt", report_style=ReportStyle.ACADEMIC)
|
||||
|
||||
# Test get with existing keys
|
||||
assert state.get("prompt") == "Test prompt"
|
||||
assert state.get("report_style") == ReportStyle.ACADEMIC
|
||||
|
||||
# Test get with non-existing keys
|
||||
assert state.get("context") is None
|
||||
assert state.get("output") is None
|
||||
assert state.get("nonexistent", "default") == "default"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_type_annotations():
|
||||
"""Test that the state accepts correct types."""
|
||||
# This test ensures the TypedDict structure is working correctly
|
||||
state = PromptEnhancerState(
|
||||
prompt="Test prompt",
|
||||
context="Test context",
|
||||
report_style=ReportStyle.POPULAR_SCIENCE,
|
||||
output="Test output",
|
||||
)
|
||||
|
||||
# Verify types
|
||||
assert isinstance(state["prompt"], str)
|
||||
assert isinstance(state["context"], str)
|
||||
assert isinstance(state["report_style"], ReportStyle)
|
||||
assert isinstance(state["output"], str)
|
||||
@@ -1,181 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import requests
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.rag.ragflow import RAGFlowProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class DummyResource:
|
||||
def __init__(self, uri, title="", description=""):
|
||||
self.uri = uri
|
||||
self.title = title
|
||||
self.description = description
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, content, similarity):
|
||||
self.content = content
|
||||
self.similarity = similarity
|
||||
|
||||
|
||||
class DummyDocument:
|
||||
def __init__(self, id, title, chunks=None):
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.chunks = chunks or []
|
||||
|
||||
|
||||
# Patch imports in ragflow.py to use dummy classes
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_imports(monkeypatch):
|
||||
import src.rag.ragflow as ragflow
|
||||
|
||||
ragflow.Resource = DummyResource
|
||||
ragflow.Chunk = DummyChunk
|
||||
ragflow.Document = DummyDocument
|
||||
yield
|
||||
|
||||
|
||||
def test_parse_uri_valid():
|
||||
uri = "rag://dataset/123#abc"
|
||||
dataset_id, document_id = parse_uri(uri)
|
||||
assert dataset_id == "123"
|
||||
assert document_id == "abc"
|
||||
|
||||
|
||||
def test_parse_uri_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
parse_uri("http://dataset/123#abc")
|
||||
|
||||
|
||||
def test_init_env_vars(monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False)
|
||||
provider = RAGFlowProvider()
|
||||
assert provider.api_url == "http://api"
|
||||
assert provider.api_key == "key"
|
||||
assert provider.page_size == 10
|
||||
|
||||
|
||||
def test_init_page_size(monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
monkeypatch.setenv("RAGFLOW_PAGE_SIZE", "5")
|
||||
provider = RAGFlowProvider()
|
||||
assert provider.page_size == 5
|
||||
|
||||
|
||||
def test_init_missing_env(monkeypatch):
|
||||
monkeypatch.delenv("RAGFLOW_API_URL", raising=False)
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
with pytest.raises(ValueError):
|
||||
RAGFlowProvider()
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.delenv("RAGFLOW_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError):
|
||||
RAGFlowProvider()
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
|
||||
def test_query_relevant_documents_success(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
resource = DummyResource("rag://dataset/123#doc456")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": {
|
||||
"doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}],
|
||||
"chunks": [
|
||||
{"document_id": "doc456", "content": "chunk text", "similarity": 0.9}
|
||||
],
|
||||
}
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
docs = provider.query_relevant_documents("query", [resource])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "doc456"
|
||||
assert docs[0].title == "Doc Title"
|
||||
assert len(docs[0].chunks) == 1
|
||||
assert docs[0].chunks[0].content == "chunk text"
|
||||
assert docs[0].chunks[0].similarity == 0.9
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
|
||||
def test_query_relevant_documents_error(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "error"
|
||||
mock_post.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.query_relevant_documents("query", [])
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_success(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "123", "name": "Dataset1", "description": "desc1"},
|
||||
{"id": "456", "name": "Dataset2", "description": "desc2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
resources = provider.list_resources()
|
||||
assert len(resources) == 2
|
||||
assert resources[0].uri == "rag://dataset/123"
|
||||
assert resources[0].title == "Dataset1"
|
||||
assert resources[0].description == "desc1"
|
||||
assert resources[1].uri == "rag://dataset/456"
|
||||
assert resources[1].title == "Dataset2"
|
||||
assert resources[1].description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_success(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "123", "name": "Dataset1", "description": "desc1"},
|
||||
{"id": "456", "name": "Dataset2", "description": "desc2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
resources = provider.list_resources()
|
||||
assert len(resources) == 2
|
||||
assert resources[0].uri == "rag://dataset/123"
|
||||
assert resources[0].title == "Dataset1"
|
||||
assert resources[0].description == "desc1"
|
||||
assert resources[1].uri == "rag://dataset/456"
|
||||
assert resources[1].title == "Dataset2"
|
||||
assert resources[1].description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_error(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "fail"
|
||||
mock_get.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.list_resources()
|
||||
@@ -1,72 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
def test_chunk_init():
|
||||
chunk = Chunk(content="test content", similarity=0.9)
|
||||
assert chunk.content == "test content"
|
||||
assert chunk.similarity == 0.9
|
||||
|
||||
|
||||
def test_document_init_and_to_dict():
|
||||
chunk1 = Chunk(content="chunk1", similarity=0.8)
|
||||
chunk2 = Chunk(content="chunk2", similarity=0.7)
|
||||
doc = Document(
|
||||
id="doc1", url="http://example.com", title="Title", chunks=[chunk1, chunk2]
|
||||
)
|
||||
assert doc.id == "doc1"
|
||||
assert doc.url == "http://example.com"
|
||||
assert doc.title == "Title"
|
||||
assert doc.chunks == [chunk1, chunk2]
|
||||
d = doc.to_dict()
|
||||
assert d["id"] == "doc1"
|
||||
assert d["content"] == "chunk1\n\nchunk2"
|
||||
assert d["url"] == "http://example.com"
|
||||
assert d["title"] == "Title"
|
||||
|
||||
|
||||
def test_document_to_dict_optional_fields():
|
||||
chunk = Chunk(content="only chunk", similarity=1.0)
|
||||
doc = Document(id="doc2", chunks=[chunk])
|
||||
d = doc.to_dict()
|
||||
assert d["id"] == "doc2"
|
||||
assert d["content"] == "only chunk"
|
||||
assert "url" not in d
|
||||
assert "title" not in d
|
||||
|
||||
|
||||
def test_resource_model():
|
||||
resource = Resource(uri="uri1", title="Resource Title")
|
||||
assert resource.uri == "uri1"
|
||||
assert resource.title == "Resource Title"
|
||||
assert resource.description == ""
|
||||
|
||||
|
||||
def test_resource_model_with_description():
|
||||
resource = Resource(uri="uri2", title="Resource2", description="desc")
|
||||
assert resource.description == "desc"
|
||||
|
||||
|
||||
def test_retriever_abstract_methods():
|
||||
class DummyRetriever(Retriever):
|
||||
def list_resources(self, query=None):
|
||||
return [Resource(uri="uri", title="title")]
|
||||
|
||||
def query_relevant_documents(self, query, resources=[]):
|
||||
return [Document(id="id", chunks=[])]
|
||||
|
||||
retriever = DummyRetriever()
|
||||
resources = retriever.list_resources()
|
||||
assert isinstance(resources, list)
|
||||
assert isinstance(resources[0], Resource)
|
||||
docs = retriever.query_relevant_documents("query", resources)
|
||||
assert isinstance(docs, list)
|
||||
assert isinstance(docs[0], Document)
|
||||
|
||||
|
||||
def test_retriever_cannot_instantiate():
|
||||
with pytest.raises(TypeError):
|
||||
Retriever()
|
||||
@@ -1,20 +1,16 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { MagicWandIcon } from "@radix-ui/react-icons";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { ArrowUp, X } from "lucide-react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { useCallback, useRef } from "react";
|
||||
|
||||
import { Detective } from "~/components/deer-flow/icons/detective";
|
||||
import MessageInput, {
|
||||
type MessageInputRef,
|
||||
} from "~/components/deer-flow/message-input";
|
||||
import { ReportStyleDialog } from "~/components/deer-flow/report-style-dialog";
|
||||
import { Tooltip } from "~/components/deer-flow/tooltip";
|
||||
import { BorderBeam } from "~/components/magicui/border-beam";
|
||||
import { Button } from "~/components/ui/button";
|
||||
import { enhancePrompt } from "~/core/api";
|
||||
import type { Option, Resource } from "~/core/messages";
|
||||
import {
|
||||
setEnableBackgroundInvestigation,
|
||||
@@ -47,16 +43,10 @@ export function InputBox({
|
||||
const backgroundInvestigation = useSettingsStore(
|
||||
(state) => state.general.enableBackgroundInvestigation,
|
||||
);
|
||||
const reportStyle = useSettingsStore((state) => state.general.reportStyle);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<MessageInputRef>(null);
|
||||
const feedbackRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Enhancement state
|
||||
const [isEnhancing, setIsEnhancing] = useState(false);
|
||||
const [isEnhanceAnimating, setIsEnhanceAnimating] = useState(false);
|
||||
const [currentPrompt, setCurrentPrompt] = useState("");
|
||||
|
||||
const handleSendMessage = useCallback(
|
||||
(message: string, resources: Array<Resource>) => {
|
||||
if (responding) {
|
||||
@@ -71,50 +61,12 @@ export function InputBox({
|
||||
resources,
|
||||
});
|
||||
onRemoveFeedback?.();
|
||||
// Clear enhancement animation after sending
|
||||
setIsEnhanceAnimating(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[responding, onCancel, onSend, feedback, onRemoveFeedback],
|
||||
);
|
||||
|
||||
const handleEnhancePrompt = useCallback(async () => {
|
||||
if (currentPrompt.trim() === "" || isEnhancing) {
|
||||
return;
|
||||
}
|
||||
|
||||
setIsEnhancing(true);
|
||||
setIsEnhanceAnimating(true);
|
||||
|
||||
try {
|
||||
const enhancedPrompt = await enhancePrompt({
|
||||
prompt: currentPrompt,
|
||||
report_style: reportStyle.toUpperCase(),
|
||||
});
|
||||
|
||||
// Add a small delay for better UX
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
|
||||
// Update the input with the enhanced prompt with animation
|
||||
if (inputRef.current) {
|
||||
inputRef.current.setContent(enhancedPrompt);
|
||||
setCurrentPrompt(enhancedPrompt);
|
||||
}
|
||||
|
||||
// Keep animation for a bit longer to show the effect
|
||||
setTimeout(() => {
|
||||
setIsEnhanceAnimating(false);
|
||||
}, 1000);
|
||||
} catch (error) {
|
||||
console.error("Failed to enhance prompt:", error);
|
||||
setIsEnhanceAnimating(false);
|
||||
// Could add toast notification here
|
||||
} finally {
|
||||
setIsEnhancing(false);
|
||||
}
|
||||
}, [currentPrompt, isEnhancing, reportStyle]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -144,65 +96,15 @@ export function InputBox({
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
{isEnhanceAnimating && (
|
||||
<motion.div
|
||||
className="pointer-events-none absolute inset-0 z-20"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<div className="relative h-full w-full">
|
||||
{/* Sparkle effect overlay */}
|
||||
<motion.div
|
||||
className="absolute inset-0 rounded-[24px] bg-gradient-to-r from-blue-500/10 via-purple-500/10 to-blue-500/10"
|
||||
animate={{
|
||||
background: [
|
||||
"linear-gradient(45deg, rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1))",
|
||||
"linear-gradient(225deg, rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1))",
|
||||
"linear-gradient(45deg, rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1))",
|
||||
],
|
||||
}}
|
||||
transition={{ duration: 2, repeat: Infinity }}
|
||||
/>
|
||||
{/* Floating sparkles */}
|
||||
{[...Array(6)].map((_, i) => (
|
||||
<motion.div
|
||||
key={i}
|
||||
className="absolute h-2 w-2 rounded-full bg-blue-400"
|
||||
style={{
|
||||
left: `${20 + i * 12}%`,
|
||||
top: `${30 + (i % 2) * 40}%`,
|
||||
}}
|
||||
animate={{
|
||||
y: [-10, -20, -10],
|
||||
opacity: [0, 1, 0],
|
||||
scale: [0.5, 1, 0.5],
|
||||
}}
|
||||
transition={{
|
||||
duration: 1.5,
|
||||
repeat: Infinity,
|
||||
delay: i * 0.2,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<MessageInput
|
||||
className={cn(
|
||||
"h-24 px-4 pt-5",
|
||||
feedback && "pt-9",
|
||||
isEnhanceAnimating && "transition-all duration-500",
|
||||
)}
|
||||
className={cn("h-24 px-4 pt-5", feedback && "pt-9")}
|
||||
ref={inputRef}
|
||||
onEnter={handleSendMessage}
|
||||
onChange={setCurrentPrompt}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center px-4 py-2">
|
||||
<div className="flex grow gap-2">
|
||||
<div className="flex grow">
|
||||
<Tooltip
|
||||
className="max-w-60"
|
||||
title={
|
||||
@@ -231,29 +133,8 @@ export function InputBox({
|
||||
<Detective /> Investigation
|
||||
</Button>
|
||||
</Tooltip>
|
||||
<ReportStyleDialog />
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<Tooltip title="Enhance prompt with AI">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"hover:bg-accent h-10 w-10",
|
||||
isEnhancing && "animate-pulse",
|
||||
)}
|
||||
onClick={handleEnhancePrompt}
|
||||
disabled={isEnhancing || currentPrompt.trim() === ""}
|
||||
>
|
||||
{isEnhancing ? (
|
||||
<div className="flex h-10 w-10 items-center justify-center">
|
||||
<div className="bg-foreground h-3 w-3 animate-bounce rounded-full opacity-70" />
|
||||
</div>
|
||||
) : (
|
||||
<MagicWandIcon className="text-brand" />
|
||||
)}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
<Tooltip title={responding ? "Stop" : "Send"}>
|
||||
<Button
|
||||
variant="outline"
|
||||
@@ -272,21 +153,6 @@ export function InputBox({
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
{isEnhancing && (
|
||||
<>
|
||||
<BorderBeam
|
||||
duration={5}
|
||||
size={250}
|
||||
className="from-transparent via-red-500 to-transparent"
|
||||
/>
|
||||
<BorderBeam
|
||||
duration={5}
|
||||
delay={3}
|
||||
size={250}
|
||||
className="from-transparent via-blue-500 to-transparent"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { Check, Copy, Headphones, Pencil, Undo2, X, Download } from "lucide-react";
|
||||
import { Check, Copy, Headphones, Pencil, Undo2, X } from "lucide-react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { ScrollContainer } from "~/components/deer-flow/scroll-container";
|
||||
@@ -64,33 +64,6 @@ export function ResearchBlock({
|
||||
}, 1000);
|
||||
}, [reportId]);
|
||||
|
||||
// Download report as markdown
|
||||
const handleDownload = useCallback(() => {
|
||||
if (!reportId) {
|
||||
return;
|
||||
}
|
||||
const report = useStore.getState().messages.get(reportId);
|
||||
if (!report) {
|
||||
return;
|
||||
}
|
||||
const now = new Date();
|
||||
const pad = (n: number) => n.toString().padStart(2, '0');
|
||||
const timestamp = `${now.getFullYear()}-${pad(now.getMonth() + 1)}-${pad(now.getDate())}_${pad(now.getHours())}-${pad(now.getMinutes())}-${pad(now.getSeconds())}`;
|
||||
const filename = `research-report-${timestamp}.md`;
|
||||
const blob = new Blob([report.content], { type: 'text/markdown' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = filename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
setTimeout(() => {
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
}, [reportId]);
|
||||
|
||||
|
||||
const handleEdit = useCallback(() => {
|
||||
setEditing((editing) => !editing);
|
||||
}, []);
|
||||
@@ -140,16 +113,6 @@ export function ResearchBlock({
|
||||
{copied ? <Check /> : <Copy />}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
<Tooltip title="Download report as markdown">
|
||||
<Button
|
||||
className="text-gray-400"
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
onClick={handleDownload}
|
||||
>
|
||||
<Download />
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
<Tooltip title="Close">
|
||||
|
||||
@@ -25,6 +25,7 @@ import type { Tab } from "./types";
|
||||
|
||||
const generalFormSchema = z.object({
|
||||
autoAcceptedPlan: z.boolean(),
|
||||
enableBackgroundInvestigation: z.boolean(),
|
||||
maxPlanIterations: z.number().min(1, {
|
||||
message: "Max plan iterations must be at least 1.",
|
||||
}),
|
||||
@@ -34,9 +35,6 @@ const generalFormSchema = z.object({
|
||||
maxSearchResults: z.number().min(1, {
|
||||
message: "Max search results must be at least 1.",
|
||||
}),
|
||||
// Others
|
||||
enableBackgroundInvestigation: z.boolean(),
|
||||
reportStyle: z.enum(["academic", "popular_science", "news", "social_media"]),
|
||||
});
|
||||
|
||||
export const GeneralTab: Tab = ({
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function Enhance(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12 2L13.09 8.26L20 9L13.09 9.74L12 16L10.91 9.74L4 9L10.91 8.26L12 2Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
<path
|
||||
d="M19 14L19.5 16.5L22 17L19.5 17.5L19 20L18.5 17.5L16 17L18.5 16.5L19 14Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
<path
|
||||
d="M5 6L5.5 7.5L7 8L5.5 8.5L5 10L4.5 8.5L3 8L4.5 7.5L5 6Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
export function ReportStyle({ className }: { className?: string }) {
|
||||
return (
|
||||
<svg
|
||||
className={className}
|
||||
version="1.1"
|
||||
width="800px"
|
||||
height="800px"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
>
|
||||
<g fill="currentcolor">
|
||||
<path
|
||||
d="M4 4C4 3.44772 4.44772 3 5 3H19C19.5523 3 20 3.44772 20 4V20C20 20.5523 19.5523 21 19 21H5C4.44772 21 4 20.5523 4 20V4Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
fill="none"
|
||||
/>
|
||||
<path
|
||||
d="M8 7H16"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M8 11H16"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M8 15H12"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<circle
|
||||
cx="16"
|
||||
cy="15"
|
||||
r="2"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</g>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -26,7 +26,6 @@ import { LoadingOutlined } from "@ant-design/icons";
|
||||
export interface MessageInputRef {
|
||||
focus: () => void;
|
||||
submit: () => void;
|
||||
setContent: (content: string) => void;
|
||||
}
|
||||
|
||||
export interface MessageInputProps {
|
||||
@@ -83,9 +82,8 @@ const MessageInput = forwardRef<MessageInputRef, MessageInputProps>(
|
||||
const debouncedUpdates = useDebouncedCallback(
|
||||
async (editor: EditorInstance) => {
|
||||
if (onChange) {
|
||||
// Get the plain text content for prompt enhancement
|
||||
const { text } = formatMessage(editor.getJSON() ?? []);
|
||||
onChange(text);
|
||||
const markdown = editor.storage.markdown.getMarkdown();
|
||||
onChange(markdown);
|
||||
}
|
||||
},
|
||||
200,
|
||||
@@ -102,12 +100,6 @@ const MessageInput = forwardRef<MessageInputRef, MessageInputProps>(
|
||||
);
|
||||
onEnter(text, resources);
|
||||
}
|
||||
editorRef.current?.commands.clearContent();
|
||||
},
|
||||
setContent: (content: string) => {
|
||||
if (editorRef.current) {
|
||||
editorRef.current.commands.setContent(content);
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { useState } from "react";
|
||||
import { Check, FileText, Newspaper, Users, GraduationCap } from "lucide-react";
|
||||
|
||||
import { Button } from "~/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "~/components/ui/dialog";
|
||||
import { setReportStyle, useSettingsStore } from "~/core/store";
|
||||
import { cn } from "~/lib/utils";
|
||||
|
||||
import { Tooltip } from "./tooltip";
|
||||
|
||||
const REPORT_STYLES = [
|
||||
{
|
||||
value: "academic" as const,
|
||||
label: "Academic",
|
||||
description: "Formal, objective, and analytical with precise terminology",
|
||||
icon: GraduationCap,
|
||||
},
|
||||
{
|
||||
value: "popular_science" as const,
|
||||
label: "Popular Science",
|
||||
description: "Engaging and accessible for general audience",
|
||||
icon: FileText,
|
||||
},
|
||||
{
|
||||
value: "news" as const,
|
||||
label: "News",
|
||||
description: "Factual, concise, and impartial journalistic style",
|
||||
icon: Newspaper,
|
||||
},
|
||||
{
|
||||
value: "social_media" as const,
|
||||
label: "Social Media",
|
||||
description: "Concise, attention-grabbing, and shareable",
|
||||
icon: Users,
|
||||
},
|
||||
];
|
||||
|
||||
export function ReportStyleDialog() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const currentStyle = useSettingsStore((state) => state.general.reportStyle);
|
||||
|
||||
const handleStyleChange = (
|
||||
style: "academic" | "popular_science" | "news" | "social_media",
|
||||
) => {
|
||||
setReportStyle(style);
|
||||
setOpen(false);
|
||||
};
|
||||
|
||||
const currentStyleConfig =
|
||||
REPORT_STYLES.find((style) => style.value === currentStyle) ||
|
||||
REPORT_STYLES[0]!;
|
||||
const CurrentIcon = currentStyleConfig.icon;
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={setOpen}>
|
||||
<Tooltip
|
||||
className="max-w-60"
|
||||
title={
|
||||
<div>
|
||||
<h3 className="mb-2 font-bold">
|
||||
Writing Style: {currentStyleConfig.label}
|
||||
</h3>
|
||||
<p>
|
||||
Choose the writing style for your research reports. Different
|
||||
styles are optimized for different audiences and purposes.
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<DialogTrigger asChild>
|
||||
<Button
|
||||
className="!border-brand !text-brand rounded-2xl"
|
||||
variant="outline"
|
||||
>
|
||||
<CurrentIcon className="h-4 w-4" /> {currentStyleConfig.label}
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
</Tooltip>
|
||||
<DialogContent className="sm:max-w-[500px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Choose Writing Style</DialogTitle>
|
||||
<DialogDescription>
|
||||
Select the writing style for your research reports. Each style is
|
||||
optimized for different audiences and purposes.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-3 py-4">
|
||||
{REPORT_STYLES.map((style) => {
|
||||
const Icon = style.icon;
|
||||
const isSelected = currentStyle === style.value;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={style.value}
|
||||
className={cn(
|
||||
"hover:bg-accent flex items-start gap-3 rounded-lg border p-4 text-left transition-colors",
|
||||
isSelected && "border-primary bg-accent",
|
||||
)}
|
||||
onClick={() => handleStyleChange(style.value)}
|
||||
>
|
||||
<Icon className="mt-0.5 h-5 w-5 shrink-0" />
|
||||
<div className="flex-1 space-y-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<h4 className="font-medium">{style.label}</h4>
|
||||
{isSelected && <Check className="text-primary h-4 w-4" />}
|
||||
</div>
|
||||
<p className="text-muted-foreground text-sm">
|
||||
{style.description}
|
||||
</p>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "~/lib/utils";
|
||||
import { motion, type MotionStyle, type Transition } from "motion/react";
|
||||
|
||||
interface BorderBeamProps {
|
||||
/**
|
||||
* The size of the border beam.
|
||||
*/
|
||||
size?: number;
|
||||
/**
|
||||
* The duration of the border beam.
|
||||
*/
|
||||
duration?: number;
|
||||
/**
|
||||
* The delay of the border beam.
|
||||
*/
|
||||
delay?: number;
|
||||
/**
|
||||
* The color of the border beam from.
|
||||
*/
|
||||
colorFrom?: string;
|
||||
/**
|
||||
* The color of the border beam to.
|
||||
*/
|
||||
colorTo?: string;
|
||||
/**
|
||||
* The motion transition of the border beam.
|
||||
*/
|
||||
transition?: Transition;
|
||||
/**
|
||||
* The class name of the border beam.
|
||||
*/
|
||||
className?: string;
|
||||
/**
|
||||
* The style of the border beam.
|
||||
*/
|
||||
style?: React.CSSProperties;
|
||||
/**
|
||||
* Whether to reverse the animation direction.
|
||||
*/
|
||||
reverse?: boolean;
|
||||
/**
|
||||
* The initial offset position (0-100).
|
||||
*/
|
||||
initialOffset?: number;
|
||||
}
|
||||
|
||||
export const BorderBeam = ({
|
||||
className,
|
||||
size = 50,
|
||||
delay = 0,
|
||||
duration = 6,
|
||||
colorFrom = "#ffaa40",
|
||||
colorTo = "#9c40ff",
|
||||
transition,
|
||||
style,
|
||||
reverse = false,
|
||||
initialOffset = 0,
|
||||
}: BorderBeamProps) => {
|
||||
return (
|
||||
<div className="pointer-events-none absolute inset-0 rounded-[inherit] border border-transparent [mask-image:linear-gradient(transparent,transparent),linear-gradient(#000,#000)] [mask-composite:intersect] [mask-clip:padding-box,border-box]">
|
||||
<motion.div
|
||||
className={cn(
|
||||
"absolute aspect-square",
|
||||
"bg-gradient-to-l from-[var(--color-from)] via-[var(--color-to)] to-transparent",
|
||||
className,
|
||||
)}
|
||||
style={
|
||||
{
|
||||
width: size,
|
||||
offsetPath: `rect(0 auto auto 0 round ${size}px)`,
|
||||
"--color-from": colorFrom,
|
||||
"--color-to": colorTo,
|
||||
...style,
|
||||
} as MotionStyle
|
||||
}
|
||||
initial={{ offsetDistance: `${initialOffset}%` }}
|
||||
animate={{
|
||||
offsetDistance: reverse
|
||||
? [`${100 - initialOffset}%`, `${-initialOffset}%`]
|
||||
: [`${initialOffset}%`, `${100 + initialOffset}%`],
|
||||
}}
|
||||
transition={{
|
||||
repeat: Infinity,
|
||||
ease: "linear",
|
||||
duration,
|
||||
delay: -delay,
|
||||
...transition,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -23,7 +23,6 @@ export async function* chatStream(
|
||||
max_search_results?: number;
|
||||
interrupt_feedback?: string;
|
||||
enable_background_investigation: boolean;
|
||||
report_style?: "academic" | "popular_science" | "news" | "social_media";
|
||||
mcp_settings?: {
|
||||
servers: Record<
|
||||
string,
|
||||
|
||||
@@ -4,5 +4,4 @@
|
||||
export * from "./chat";
|
||||
export * from "./mcp";
|
||||
export * from "./podcast";
|
||||
export * from "./prompt-enhancer";
|
||||
export * from "./types";
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { resolveServiceURL } from "./resolve-service-url";
|
||||
|
||||
export interface EnhancePromptRequest {
|
||||
prompt: string;
|
||||
context?: string;
|
||||
report_style?: string;
|
||||
}
|
||||
|
||||
export interface EnhancePromptResponse {
|
||||
enhanced_prompt: string;
|
||||
}
|
||||
|
||||
export async function enhancePrompt(
|
||||
request: EnhancePromptRequest,
|
||||
): Promise<string> {
|
||||
const response = await fetch(resolveServiceURL("prompt/enhance"), {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("Raw API response:", data); // Debug log
|
||||
|
||||
// The backend now returns the enhanced prompt directly in the result field
|
||||
let enhancedPrompt = data.result;
|
||||
|
||||
// If the result is somehow still a JSON object, extract the enhanced_prompt
|
||||
if (typeof enhancedPrompt === "object" && enhancedPrompt.enhanced_prompt) {
|
||||
enhancedPrompt = enhancedPrompt.enhanced_prompt;
|
||||
}
|
||||
|
||||
// If the result is a JSON string, try to parse it
|
||||
if (typeof enhancedPrompt === "string") {
|
||||
try {
|
||||
const parsed = JSON.parse(enhancedPrompt);
|
||||
if (parsed.enhanced_prompt) {
|
||||
enhancedPrompt = parsed.enhanced_prompt;
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, use the string as-is (which is what we want)
|
||||
console.log("Using enhanced prompt as-is:", enhancedPrompt);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to original prompt if something went wrong
|
||||
if (!enhancedPrompt || enhancedPrompt.trim() === "") {
|
||||
console.warn("No enhanced prompt received, using original");
|
||||
enhancedPrompt = request.prompt;
|
||||
}
|
||||
|
||||
return enhancedPrompt;
|
||||
}
|
||||
@@ -14,7 +14,6 @@ const DEFAULT_SETTINGS: SettingsState = {
|
||||
maxPlanIterations: 1,
|
||||
maxStepNum: 3,
|
||||
maxSearchResults: 3,
|
||||
reportStyle: "academic",
|
||||
},
|
||||
mcp: {
|
||||
servers: [],
|
||||
@@ -28,7 +27,6 @@ export type SettingsState = {
|
||||
maxPlanIterations: number;
|
||||
maxStepNum: number;
|
||||
maxSearchResults: number;
|
||||
reportStyle: "academic" | "popular_science" | "news" | "social_media";
|
||||
};
|
||||
mcp: {
|
||||
servers: MCPServerMetadata[];
|
||||
@@ -127,16 +125,6 @@ export const getChatStreamSettings = () => {
|
||||
};
|
||||
};
|
||||
|
||||
export function setReportStyle(value: "academic" | "popular_science" | "news" | "social_media") {
|
||||
useSettingsStore.setState((state) => ({
|
||||
general: {
|
||||
...state.general,
|
||||
reportStyle: value,
|
||||
},
|
||||
}));
|
||||
saveSettings();
|
||||
}
|
||||
|
||||
export function setEnableBackgroundInvestigation(value: boolean) {
|
||||
useSettingsStore.setState((state) => ({
|
||||
general: {
|
||||
|
||||
@@ -109,7 +109,6 @@ export async function sendMessage(
|
||||
max_plan_iterations: settings.maxPlanIterations,
|
||||
max_step_num: settings.maxStepNum,
|
||||
max_search_results: settings.maxSearchResults,
|
||||
report_style: settings.reportStyle,
|
||||
mcp_settings: settings.mcpSettings,
|
||||
},
|
||||
options,
|
||||
|
||||
Reference in New Issue
Block a user