- """
- 研究相关API路由
- """
- from fastapi import APIRouter, HTTPException, BackgroundTasks, Query, WebSocket, WebSocketDisconnect
- from pydantic import BaseModel, Field
- from typing import List, Dict, Any, Optional
- import logging
- import asyncio
- import json
- import time
- from pathlib import Path
- import uuid
- from fastapi.responses import StreamingResponse
- from backend.core.research import ResearchAgent
- from backend.core.clustering import PaperClusterer
- from backend.core.report import ReportGenerator
- from backend.config import MAX_SEARCH_RESULTS
- logger = logging.getLogger(__name__)
- router = APIRouter(prefix="/research", tags=["research"])

# Data models
class ResearchRequest(BaseModel):
    research_intent: str = Field(..., description="The user's research intent")
    max_results: Optional[int] = Field(None, description="Deprecated; each research direction returns a fixed 3-5 papers")


class KeywordsRequest(BaseModel):
    research_intent: str = Field(..., description="The user's research intent")


class PaperSearchRequest(BaseModel):
    keywords: List[str] = Field(..., description="Search keywords")
    max_results: int = Field(MAX_SEARCH_RESULTS, description="Maximum number of search results")


class ClusterRequest(BaseModel):
    papers: List[Dict[str, Any]] = Field(..., description="Papers to cluster")
    num_clusters: int = Field(0, description="Number of clusters; 0 means determine automatically")


class ReportRequest(BaseModel):
    research_intent: str = Field(..., description="Research intent")
    keywords: List[str] = Field(..., description="Keywords")
    papers: List[Dict[str, Any]] = Field(..., description="Retrieved papers")
    clusters: Optional[Dict[str, Any]] = Field(None, description="Clustering result")

# Global instances
research_agent = ResearchAgent()
paper_clusterer = PaperClusterer()
report_generator = ReportGenerator()

# Global dict storing per-task progress information
task_progress = {}

def update_task_progress(task_id: str, progress: int, status: str, message: str = ""):
    """Update a task's progress and evict stale entries."""
    task_progress[task_id] = {
        "percentage": progress,
        "status": status,
        "message": message,
        "updated_at": time.time()
    }

    # Evict tasks that have not been updated for over an hour
    current_time = time.time()
    expired_tasks = [tid for tid, data in task_progress.items()
                     if current_time - data["updated_at"] > 3600]
    for tid in expired_tasks:
        del task_progress[tid]

async def get_task_progress(task_id: str):
    """Return a task's progress, or a placeholder if the task is unknown or expired."""
    if task_id in task_progress:
        return task_progress[task_id]
    return {
        "percentage": 0,
        "status": "unknown",
        "message": "任务未找到或已过期",
        "updated_at": time.time()
    }
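
# A minimal sketch of how a background job might report progress through the
# two helpers above (illustrative only; `run_research_task` is a hypothetical
# name, not part of this module):
#
#     async def run_research_task(task_id: str, intent: str):
#         update_task_progress(task_id, 0, "running", "任务已启动")
#         try:
#             await ResearchAgent().process_research_intent(intent, max_results=None)
#             update_task_progress(task_id, 100, "completed", "任务完成")
#         except Exception as exc:
#             update_task_progress(task_id, 100, "error", str(exc))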

# Route definitions
@router.post("/process")
async def process_research(request: ResearchRequest):
    """Process a full research request."""
    try:
        # Use a fresh agent per request so no state leaks between calls
        agent = ResearchAgent()

        result = await agent.process_research_intent(
            request.research_intent,
            max_results=None  # deprecated; the agent selects 3-5 papers per direction
        )

        if result["status"] == "error":
            raise HTTPException(
                status_code=500,
                detail=result.get("error", "未知错误")
            )

        return result

    except HTTPException:
        # Re-raise untouched so the blanket handler below does not rewrap it
        raise
    except Exception as e:
        logger.error(f"Error processing research: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
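
# Example call (illustrative only; assumes the `httpx` package and a server on
# localhost:8000 with this router mounted without an extra prefix):
#
#     import httpx
#
#     resp = httpx.post(
#         "http://localhost:8000/research/process",
#         json={"research_intent": "diffusion models for video generation"},
#         timeout=600.0,
#     )
#     resp.raise_for_status()
#     print(resp.json()["status"])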
- @router.post("/extract-keywords")
- async def extract_keywords(request: KeywordsRequest):
- """从研究意图中提取关键词"""
- try:
- logger.info(f"Extracting keywords from: {request.research_intent}")
- result = await research_agent.llm_client.extract_keywords(
- research_topic=request.research_intent
- )
- return {"keywords": result}
- except Exception as e:
- logger.error(f"Error extracting keywords: {str(e)}", exc_info=True)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post("/search-papers")
- async def search_papers(request: PaperSearchRequest):
- """根据关键词检索论文"""
- try:
- logger.info(f"Searching papers with keywords: {request.keywords}")
- papers = []
- for keyword in request.keywords:
- results = await research_agent.arxiv_client.search_papers(
- query=keyword,
- max_results=max(3, request.max_results // len(request.keywords))
- )
- papers.extend(results)
-
- # 去重
- unique_papers = []
- paper_ids = set()
- for paper in papers:
- if paper["id"] not in paper_ids:
- unique_papers.append(paper)
- paper_ids.add(paper["id"])
-
- return {"papers": unique_papers, "count": len(unique_papers)}
- except Exception as e:
- logger.error(f"Error searching papers: {str(e)}", exc_info=True)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post("/cluster-papers")
- async def cluster_papers(request: ClusterRequest):
- """对论文进行聚类分析"""
- try:
- logger.info(f"Clustering {len(request.papers)} papers")
- result = paper_clusterer.cluster_papers(
- papers=request.papers,
- num_clusters=request.num_clusters
- )
- return result
- except Exception as e:
- logger.error(f"Error clustering papers: {str(e)}", exc_info=True)
- raise HTTPException(status_code=500, detail=str(e))
- @router.post("/generate-report")
- async def generate_report(request: ReportRequest):
- """生成研究报告"""
- try:
- logger.info(f"Generating report for: {request.research_intent}")
- report = await report_generator.generate_report(
- research_intent=request.research_intent,
- keywords=request.keywords,
- papers=request.papers,
- clusters=request.clusters
- )
- return report
- except Exception as e:
- logger.error(f"Error generating report: {str(e)}", exc_info=True)
- raise HTTPException(status_code=500, detail=str(e))

# WebSocket endpoint for task progress updates
@router.websocket("/ws/research/{task_id}")
async def research_progress(websocket: WebSocket, task_id: str):
    await websocket.accept()

    try:
        # Poll the task's progress once per second and push it to the client
        while True:
            progress = await get_task_progress(task_id)
            await websocket.send_json(progress)

            # Stop once the task has finished or failed
            if progress["status"] in ["completed", "error"]:
                break

            await asyncio.sleep(1)
    except WebSocketDisconnect:
        logger.info(f"Client disconnected from progress updates for task {task_id}")

# Event-stream (SSE) endpoint
@router.get("/process-stream")
async def process_research_stream(research_intent: str):
    """Return research processing results as a server-sent event stream."""

    if not research_intent or research_intent.strip() == "":
        raise HTTPException(status_code=400, detail="研究主题不能为空")

    async def event_generator():
        try:
            # Initial status
            yield f"data: {json.dumps({'status': 'starting', 'message': '正在开始研究流程...'})}\n\n"
            await asyncio.sleep(0.1)

            research_agent = ResearchAgent()

            # 1. Extract keywords
            yield f"data: {json.dumps({'status': 'extracting_keywords', 'message': '正在提取关键词...'})}\n\n"
            keywords_data = await research_agent.extract_keywords_only(research_intent)

            # Send the keyword result
            keywords_result = {
                'status': 'keywords_ready',
                'data': keywords_data,
                'message': '关键词提取完成'
            }
            yield f"data: {json.dumps(keywords_result)}\n\n"
            await asyncio.sleep(0.1)

            # 2. Generate research directions
            yield f"data: {json.dumps({'status': 'generating_directions', 'message': '正在生成研究方向...'})}\n\n"
            directions_data = await research_agent.generate_directions_only(
                keywords_data["english_keywords"],
                keywords_data["language"]
            )

            # Send the research-direction result
            directions_result = {
                'status': 'directions_ready',
                'data': directions_data,
                'message': '研究方向生成完成'
            }
            yield f"data: {json.dumps(directions_result)}\n\n"
            await asyncio.sleep(0.1)

            # 3. Process each research direction
            for i, direction in enumerate(directions_data["english_directions"]):
                # Send the search status
                search_status = {
                    'status': 'searching_papers',
                    'message': f'正在搜索方向 {i+1}/{len(directions_data["english_directions"])} 的相关论文...',
                    'current_direction': i
                }
                yield f"data: {json.dumps(search_status)}\n\n"

                # Look up the direction in the original language, falling back to English
                if i < len(directions_data["original_directions"]):
                    original_dir = directions_data["original_directions"][i]
                else:
                    original_dir = direction

                # Search for papers
                papers_data = await research_agent.search_papers_for_direction(direction)

                # Send the paper result
                papers_result = {
                    'status': 'papers_ready',
                    'data': {
                        'direction': direction,
                        'original_direction': original_dir,
                        'papers': papers_data,
                        'direction_index': i
                    },
                    'message': f'方向 {i+1} 的论文搜索完成'
                }
                yield f"data: {json.dumps(papers_result)}\n\n"
                await asyncio.sleep(0.1)

                # Generate a report if any papers were found
                if papers_data:
                    report_status = {
                        'status': 'generating_report',
                        'message': f'正在为方向 {i+1} 生成研究报告...',
                        'current_direction': i
                    }
                    yield f"data: {json.dumps(report_status)}\n\n"

                    # Generate the report
                    report_data = await research_agent.generate_report_for_direction(
                        direction,
                        papers_data,
                        keywords_data["language"]
                    )

                    # Send the report result
                    report_result = {
                        'status': 'report_ready',
                        'data': {
                            'direction': direction,
                            'original_direction': original_dir,
                            'report': report_data,
                            'direction_index': i
                        },
                        'message': f'方向 {i+1} 的研究报告生成完成'
                    }
                    yield f"data: {json.dumps(report_result)}\n\n"
                    await asyncio.sleep(0.1)

            # All done
            yield f"data: {json.dumps({'status': 'completed', 'message': '研究流程全部完成'})}\n\n"

        except Exception as e:
            logger.error(f"Error in streaming process: {str(e)}", exc_info=True)
            error_result = {'status': 'error', 'message': f'处理过程中发生错误: {str(e)}'}
            yield f"data: {json.dumps(error_result)}\n\n"

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no'
        }
    )
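
# A minimal consumer sketch for this stream (illustrative only; assumes the
# `httpx` package and a server on localhost:8000):
#
#     import httpx, json
#
#     with httpx.stream(
#         "GET",
#         "http://localhost:8000/research/process-stream",
#         params={"research_intent": "graph neural networks"},
#         timeout=None,
#     ) as resp:
#         for line in resp.iter_lines():
#             if line.startswith("data: "):
#                 event = json.loads(line[len("data: "):])
#                 print(event["status"], event.get("message", ""))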
- @router.get("/enhanced-process-stream")
- async def enhanced_process_stream(research_intent: str):
- """以增强的事件流形式返回研究处理结果,按照研究主题逐一处理"""
-
- if not research_intent or research_intent.strip() == "":
- raise HTTPException(status_code=400, detail="研究主题不能为空")
-
- async def event_generator():
- try:
- # 初始状态
- yield f"data: {json.dumps({'status': 'starting', 'message': '正在开始研究流程...'})}\n\n"
- await asyncio.sleep(0.1)
-
- research_agent = ResearchAgent()
-
- # 1. 提取基础关键词
- yield f"data: {json.dumps({'status': 'extracting_base_keywords', 'message': '正在提取基础关键词...'})}\n\n"
- base_keywords_data = await research_agent.extract_keywords_only(research_intent)
-
- # 发送基础关键词结果
- base_keywords_result = {
- 'status': 'base_keywords_extracted',
- 'data': {
- 'language': base_keywords_data['language'],
- 'keywords': base_keywords_data['english_keywords'],
- 'translations': base_keywords_data['original_keywords'] if base_keywords_data['language'] != 'en' else None
- },
- 'message': '基础关键词提取完成'
- }
- yield f"data: {json.dumps(base_keywords_result)}\n\n"
- await asyncio.sleep(0.1)
-
            # 2. Generate research topics
            yield f"data: {json.dumps({'status': 'generating_research_topics', 'message': '正在基于关键词生成研究方向...'})}\n\n"

            # Fall back to a simple implementation if the method is missing
            if not hasattr(research_agent, 'generate_enhanced_topics'):
                logger.warning("Method generate_enhanced_topics not found, using fallback implementation")
                research_topics = []
                for keyword in base_keywords_data["english_keywords"][:3]:
                    research_topics.append({
                        "english_title": f"Research on {keyword}",
                        "title": f"关于{keyword}的研究" if base_keywords_data["language"] != "en" else f"Research on {keyword}",
                        "description": f"Investigating various aspects of {keyword} in the context of the research intent",
                        "keywords": [keyword]
                    })
            else:
                research_topics = await research_agent.generate_enhanced_topics(
                    base_keywords_data["english_keywords"],
                    base_keywords_data["language"]
                )

            # Send the research-topic result
            topics_result = {
                'status': 'research_topics_generated',
                'data': {
                    'language': base_keywords_data['language'],
                    'research_topics': research_topics
                },
                'message': '研究方向生成完成'
            }
            yield f"data: {json.dumps(topics_result)}\n\n"
            await asyncio.sleep(0.1)

            # Log progress for verification
            logger.info(f"Generated {len(research_topics)} research topics")
            # 3. Process each research topic; per-step fallbacks keep one
            #    unimplemented or failing method from aborting the whole stream
            for i, topic in enumerate(research_topics):
                try:
                    # 3.1 Generate dedicated search keywords for this topic
                    yield f"data: {json.dumps({'status': 'generating_search_keywords', 'message': f'正在为研究方向 {i+1}/{len(research_topics)} 生成搜索关键词...'})}\n\n"

                    if hasattr(research_agent, 'generate_search_keywords_for_topic'):
                        search_keywords = await research_agent.generate_search_keywords_for_topic(topic)
                    else:
                        # Fallback: use the topic's own keywords, or its title
                        search_keywords = topic.get('keywords', [])
                        if not search_keywords:
                            search_keywords = [topic['english_title']]

                    # Sanity check: the keywords must be a list of strings, and at
                    # least one must be longer than three characters (to filter out
                    # error messages and junk); otherwise fall back to the topic
                    # title plus its own keywords
                    if not (isinstance(search_keywords, list)
                            and all(isinstance(kw, str) for kw in search_keywords)
                            and any(len(kw) > 3 for kw in search_keywords)):
                        search_keywords = [topic['english_title']] + topic.get('keywords', [])

                    # Send the search keywords
                    keywords_result = {
                        'status': 'search_keywords_generated',
                        'data': {
                            'language': base_keywords_data['language'],
                            'keywords': search_keywords,
                            'topic_index': i,
                            'topic': topic['english_title']
                        },
                        'message': f'研究方向 {i+1} 的搜索关键词生成完成'
                    }
                    yield f"data: {json.dumps(keywords_result)}\n\n"
                    await asyncio.sleep(0.1)

                    logger.info(f"Generated search keywords for topic {i+1}: {search_keywords}")
                except Exception as e:
                    logger.error(f"Error generating search keywords for topic {i+1}: {str(e)}", exc_info=True)
                    # Build the payload with json.dumps so quotes in the error message cannot break the JSON
                    yield f"data: {json.dumps({'status': 'warning', 'message': f'生成方向 {i+1} 搜索关键词失败: {str(e)}'})}\n\n"
                    # Use a more reliable fallback set of keywords
                    search_keywords = [topic['english_title']]
                    if isinstance(topic.get('keywords'), list) and topic['keywords']:
                        search_keywords.extend(topic['keywords'][:2])
                # 3.2 Search for papers
                try:
                    yield f"data: {json.dumps({'status': 'searching_papers', 'message': f'正在搜索方向 {i+1}/{len(research_topics)} 的相关论文...', 'current_direction': i})}\n\n"

                    if hasattr(research_agent, 'search_papers_with_keywords'):
                        papers_data = await research_agent.search_papers_with_keywords(search_keywords)
                    else:
                        # Fallback: reuse the existing per-direction search
                        papers_data = await research_agent.search_papers_for_direction(topic['english_title'])

                    logger.info(f"Found {len(papers_data)} papers for topic {i+1}")
                except Exception as e:
                    logger.error(f"Error searching papers for topic {i+1}: {str(e)}", exc_info=True)
                    yield f"data: {json.dumps({'status': 'warning', 'message': f'搜索方向 {i+1} 论文失败: {str(e)}'})}\n\n"
                    # Fall back to an empty list
                    papers_data = []

                # 3.3 Cluster the papers (optional step). A single pseudo-cluster
                # wrapping all papers serves as the fallback in every failure mode.
                fallback_clusters = {
                    "papers": papers_data,
                    "clusters": [{
                        "id": 0,
                        "name": topic["english_title"],
                        "keywords": search_keywords[:3] if search_keywords else []
                    }]
                }
                try:
                    if papers_data and len(papers_data) >= 3:
                        yield f"data: {json.dumps({'status': 'clustering_papers', 'message': f'正在对方向 {i+1} 的论文进行聚类分析...'})}\n\n"

                        if hasattr(research_agent, 'cluster_papers_by_keywords'):
                            clustered_papers = await research_agent.cluster_papers_by_keywords(
                                papers_data, search_keywords, topic
                            )
                        else:
                            clustered_papers = fallback_clusters
                    else:
                        # Too few papers to cluster meaningfully
                        clustered_papers = fallback_clusters
                except Exception as e:
                    logger.error(f"Error clustering papers for topic {i+1}: {str(e)}", exc_info=True)
                    clustered_papers = fallback_clusters

                # 3.4 Send the paper result
                papers_result = {
                    'status': 'papers_ready',
                    'data': {
                        'direction': topic['english_title'],
                        'original_direction': topic['title'],
                        'search_keywords': search_keywords,
                        'papers': clustered_papers['papers'],
                        'cluster_info': clustered_papers.get('clusters', []),
                        'direction_index': i
                    },
                    'message': f'方向 {i+1} 的论文搜索和聚类完成'
                }
                yield f"data: {json.dumps(papers_result)}\n\n"
                await asyncio.sleep(0.1)
                # 3.5 Generate a report if any papers were found
                if papers_data:
                    try:
                        yield f"data: {json.dumps({'status': 'generating_report', 'message': f'正在为方向 {i+1} 生成研究报告...', 'current_direction': i})}\n\n"
                        await asyncio.sleep(0.1)

                        # Bound the report generation time with asyncio.wait_for
                        try:
                            report_data = await asyncio.wait_for(
                                research_agent.generate_enhanced_report(
                                    topic, clustered_papers, search_keywords, base_keywords_data["language"]
                                ),
                                timeout=300.0
                            )

                            # Send the report result
                            report_result = {
                                'status': 'report_ready',
                                'data': {
                                    'direction': topic['english_title'],
                                    'original_direction': topic['title'],
                                    'search_keywords': search_keywords,
                                    'report': report_data,
                                    'direction_index': i
                                },
                                'message': f'方向 {i+1} 的研究报告生成完成'
                            }
                            yield f"data: {json.dumps(report_result)}\n\n"
                            logger.info(f"Generated report for topic {i+1}")
                        except asyncio.TimeoutError:
                            # Replace a timed-out report with a simplified fallback
                            logger.warning(f"Report generation timed out for topic {i+1}")
                            fallback_report = {
                                "english_content": f"# Report generation timed out\n\nWe apologize, but the report generation for this topic took too long and was automatically stopped. This topic has {len(papers_data)} related papers that you can review manually.",
                                "translated_content": f"# 报告生成超时\n\n很抱歉,此主题的报告生成时间过长,已自动停止。该主题有{len(papers_data)}篇相关论文,您可以手动查看。"
                            }

                            fallback_result = {
                                'status': 'report_ready',
                                'data': {
                                    'direction': topic['english_title'],
                                    'original_direction': topic['title'],
                                    'search_keywords': search_keywords,
                                    'report': fallback_report,
                                    'direction_index': i,
                                    'is_fallback': True
                                },
                                'message': f'方向 {i+1} 的报告生成超时,提供简化版本'
                            }
                            yield f"data: {json.dumps(fallback_result)}\n\n"
                            await asyncio.sleep(0.1)

                        # Tell the frontend this topic is fully processed
                        yield f"data: {json.dumps({'status': 'direction_completed', 'message': f'方向 {i+1} 处理完成', 'direction_index': i})}\n\n"
                        await asyncio.sleep(0.1)
                    except Exception as e:
                        logger.error(f"Error handling report generation for topic {i+1}: {str(e)}", exc_info=True)
                        yield f"data: {json.dumps({'status': 'warning', 'message': f'生成方向 {i+1} 研究报告处理失败: {str(e)}'})}\n\n"
                        await asyncio.sleep(0.1)
                else:
                    logger.warning(f"No papers found for topic {i+1}, skipping report generation")
        except Exception as e:
            logger.error(f"Error in enhanced streaming process: {str(e)}", exc_info=True)
            error_result = {'status': 'error', 'message': f'处理过程中发生错误: {str(e)}'}
            yield f"data: {json.dumps(error_result)}\n\n"

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no'
        }
    )
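
# For reference, a client of /enhanced-process-stream can expect events in this
# order (statuses taken from the generator above): 'starting',
# 'extracting_base_keywords', 'base_keywords_extracted',
# 'generating_research_topics', 'research_topics_generated'; then per topic:
# 'generating_search_keywords', 'search_keywords_generated', 'searching_papers',
# optionally 'clustering_papers', 'papers_ready', 'generating_report',
# 'report_ready', 'direction_completed'. Any step may instead emit 'warning',
# and a fatal failure emits 'error'.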