""" 研究相关API路由 """ from fastapi import APIRouter, HTTPException, BackgroundTasks, Query, WebSocket, WebSocketDisconnect from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional import logging import asyncio import json import time from pathlib import Path import uuid from fastapi.responses import StreamingResponse from backend.core.research import ResearchAgent from backend.core.clustering import PaperClusterer from backend.core.report import ReportGenerator from backend.config import MAX_SEARCH_RESULTS logger = logging.getLogger(__name__) router = APIRouter(prefix="/research", tags=["research"]) # 数据模型 class ResearchRequest(BaseModel): research_intent: str = Field(..., description="用户的研究意图") max_results: Optional[int] = Field(None, description="已废弃,每个研究方向固定返回3-5篇论文") class KeywordsRequest(BaseModel): research_intent: str = Field(..., description="用户的研究意图") class PaperSearchRequest(BaseModel): keywords: List[str] = Field(..., description="检索关键词") max_results: int = Field(MAX_SEARCH_RESULTS, description="最大检索结果数量") class ClusterRequest(BaseModel): papers: List[Dict[str, Any]] = Field(..., description="要聚类的论文") num_clusters: int = Field(0, description="聚类数量,0表示自动确定") class ReportRequest(BaseModel): research_intent: str = Field(..., description="研究意图") keywords: List[str] = Field(..., description="关键词") papers: List[Dict[str, Any]] = Field(..., description="检索到的论文") clusters: Optional[Dict[str, Any]] = Field(None, description="聚类结果") # 全局实例 research_agent = ResearchAgent() paper_clusterer = PaperClusterer() report_generator = ReportGenerator() # 存储进度信息的全局字典 task_progress = {} # 添加函数更新进度 def update_task_progress(task_id: str, progress: int, status: str, message: str = ""): """更新任务进度""" task_progress[task_id] = { "percentage": progress, "status": status, "message": message, "updated_at": time.time() } # 清理旧任务 current_time = time.time() expired_tasks = [tid for tid, data in task_progress.items() if current_time - data["updated_at"] > 3600] # 1小时过期 for tid in expired_tasks: del task_progress[tid] # 添加函数获取进度 async def get_task_progress(task_id: str): """获取任务进度""" if task_id in task_progress: return task_progress[task_id] else: return { "percentage": 0, "status": "unknown", "message": "任务未找到或已过期", "updated_at": time.time() } # 路由定义 @router.post("/process") async def process_research(request: ResearchRequest): """处理研究请求""" try: # 移除 progress_callback 参数 research_agent = ResearchAgent() # 处理研究 result = await research_agent.process_research_intent( request.research_intent, max_results=None ) if result["status"] == "error": raise HTTPException( status_code=500, detail=result.get("error", "未知错误") ) return result except Exception as e: logger.error(f"Error processing research: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=str(e) ) @router.post("/extract-keywords") async def extract_keywords(request: KeywordsRequest): """从研究意图中提取关键词""" try: logger.info(f"Extracting keywords from: {request.research_intent}") result = await research_agent.llm_client.extract_keywords( research_topic=request.research_intent ) return {"keywords": result} except Exception as e: logger.error(f"Error extracting keywords: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post("/search-papers") async def search_papers(request: PaperSearchRequest): """根据关键词检索论文""" try: logger.info(f"Searching papers with keywords: {request.keywords}") papers = [] for keyword in request.keywords: results = await research_agent.arxiv_client.search_papers( query=keyword, max_results=max(3, request.max_results // len(request.keywords)) ) papers.extend(results) # 去重 unique_papers = [] paper_ids = set() for paper in papers: if paper["id"] not in paper_ids: unique_papers.append(paper) paper_ids.add(paper["id"]) return {"papers": unique_papers, "count": len(unique_papers)} except Exception as e: logger.error(f"Error searching papers: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post("/cluster-papers") async def cluster_papers(request: ClusterRequest): """对论文进行聚类分析""" try: logger.info(f"Clustering {len(request.papers)} papers") result = paper_clusterer.cluster_papers( papers=request.papers, num_clusters=request.num_clusters ) return result except Exception as e: logger.error(f"Error clustering papers: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @router.post("/generate-report") async def generate_report(request: ReportRequest): """生成研究报告""" try: logger.info(f"Generating report for: {request.research_intent}") report = await report_generator.generate_report( research_intent=request.research_intent, keywords=request.keywords, papers=request.papers, clusters=request.clusters ) return report except Exception as e: logger.error(f"Error generating report: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) # 添加WebSocket端点 @router.websocket("/ws/research/{task_id}") async def research_progress(websocket: WebSocket, task_id: str): await websocket.accept() try: # 等待进度更新 while True: # 查询任务进度 progress = await get_task_progress(task_id) # 发送给客户端 await websocket.send_json(progress) # 如果任务完成,退出循环 if progress["status"] in ["completed", "error"]: break # 等待下一次更新 await asyncio.sleep(1) except WebSocketDisconnect: logger.info(f"Client disconnected from progress updates for task {task_id}") # 创建事件流响应 @router.get("/process-stream") async def process_research_stream(research_intent: str): """以事件流形式返回研究处理结果""" if not research_intent or research_intent.strip() == "": raise HTTPException(status_code=400, detail="研究主题不能为空") async def event_generator(): try: # 初始状态 yield f"data: {json.dumps({'status': 'starting', 'message': '正在开始研究流程...'})}\n\n" await asyncio.sleep(0.1) research_agent = ResearchAgent() # 1. 提取关键词 yield f"data: {json.dumps({'status': 'extracting_keywords', 'message': '正在提取关键词...'})}\n\n" keywords_data = await research_agent.extract_keywords_only(research_intent) # 发送关键词结果 keywords_result = { 'status': 'keywords_ready', 'data': keywords_data, 'message': '关键词提取完成' } yield f"data: {json.dumps(keywords_result)}\n\n" await asyncio.sleep(0.1) # 2. 生成研究方向 yield f"data: {json.dumps({'status': 'generating_directions', 'message': '正在生成研究方向...'})}\n\n" directions_data = await research_agent.generate_directions_only( keywords_data["english_keywords"], keywords_data["language"] ) # 发送研究方向结果 directions_result = { 'status': 'directions_ready', 'data': directions_data, 'message': '研究方向生成完成' } yield f"data: {json.dumps(directions_result)}\n\n" await asyncio.sleep(0.1) # 3. 处理每个研究方向 for i, direction in enumerate(directions_data["english_directions"]): # 发送搜索状态 search_status = { 'status': 'searching_papers', 'message': f'正在搜索方向 {i+1}/{len(directions_data["english_directions"])} 的相关论文...', 'current_direction': i } yield f"data: {json.dumps(search_status)}\n\n" # 获取原始语言的方向 original_dir = directions_data["original_directions"][i] if i < len(directions_data["original_directions"]) else direction # 搜索论文 papers_data = await research_agent.search_papers_for_direction(direction) # 发送论文结果 papers_result = { 'status': 'papers_ready', 'data': { 'direction': direction, 'original_direction': original_dir, 'papers': papers_data, 'direction_index': i }, 'message': f'方向 {i+1} 的论文搜索完成' } yield f"data: {json.dumps(papers_result)}\n\n" await asyncio.sleep(0.1) # 如果有论文,生成报告 if papers_data: report_status = { 'status': 'generating_report', 'message': f'正在为方向 {i+1} 生成研究报告...', 'current_direction': i } yield f"data: {json.dumps(report_status)}\n\n" # 生成报告 report_data = await research_agent.generate_report_for_direction( direction, papers_data, keywords_data["language"] ) # 发送报告结果 report_result = { 'status': 'report_ready', 'data': { 'direction': direction, 'original_direction': original_dir, 'report': report_data, 'direction_index': i }, 'message': f'方向 {i+1} 的研究报告生成完成' } yield f"data: {json.dumps(report_result)}\n\n" await asyncio.sleep(0.1) # 全部完成 yield f"data: {json.dumps({'status': 'completed', 'message': '研究流程全部完成'})}\n\n" except Exception as e: # 错误处理 logger.error(f"Error in streaming process: {str(e)}", exc_info=True) error_result = {'status': 'error', 'message': f'处理过程中发生错误: {str(e)}'} yield f"data: {json.dumps(error_result)}\n\n" # 返回流式响应 return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no' } ) @router.get("/enhanced-process-stream") async def enhanced_process_stream(research_intent: str): """以增强的事件流形式返回研究处理结果,按照研究主题逐一处理""" if not research_intent or research_intent.strip() == "": raise HTTPException(status_code=400, detail="研究主题不能为空") async def event_generator(): try: # 初始状态 yield f"data: {json.dumps({'status': 'starting', 'message': '正在开始研究流程...'})}\n\n" await asyncio.sleep(0.1) research_agent = ResearchAgent() # 1. 提取基础关键词 yield f"data: {json.dumps({'status': 'extracting_base_keywords', 'message': '正在提取基础关键词...'})}\n\n" base_keywords_data = await research_agent.extract_keywords_only(research_intent) # 发送基础关键词结果 base_keywords_result = { 'status': 'base_keywords_extracted', 'data': { 'language': base_keywords_data['language'], 'keywords': base_keywords_data['english_keywords'], 'translations': base_keywords_data['original_keywords'] if base_keywords_data['language'] != 'en' else None }, 'message': '基础关键词提取完成' } yield f"data: {json.dumps(base_keywords_result)}\n\n" await asyncio.sleep(0.1) # 2. 生成研究方向 yield f"data: {json.dumps({'status': 'generating_research_topics', 'message': '正在基于关键词生成研究方向...'})}\n\n" # 确保方法存在 if not hasattr(research_agent, 'generate_enhanced_topics'): # 临时实现一个简单版本作为后备 logger.warning("Method generate_enhanced_topics not found, using fallback implementation") research_topics = [] for i, keyword in enumerate(base_keywords_data["english_keywords"][:3]): research_topics.append({ "english_title": f"Research on {keyword}", "title": f"关于{keyword}的研究" if base_keywords_data["language"] != "en" else f"Research on {keyword}", "description": f"Investigating various aspects of {keyword} in the context of the research intent", "keywords": [keyword] }) else: research_topics = await research_agent.generate_enhanced_topics( base_keywords_data["english_keywords"], base_keywords_data["language"] ) # 发送研究方向结果 topics_result = { 'status': 'research_topics_generated', 'data': { 'language': base_keywords_data['language'], 'research_topics': research_topics }, 'message': '研究方向生成完成' } yield f"data: {json.dumps(topics_result)}\n\n" await asyncio.sleep(0.1) # 记录日志,验证进度 logger.info(f"Generated {len(research_topics)} research topics") # 3. 处理每个研究方向 - 防止后续方法未实现导致整个流程中断 for i, topic in enumerate(research_topics): try: # 3.1 为这个方向生成专门的搜索关键词 yield f"data: {json.dumps({'status': 'generating_search_keywords', 'message': f'正在为研究方向 {i+1}/{len(research_topics)} 生成搜索关键词...'})}\n\n" # 检查方法是否存在 if hasattr(research_agent, 'generate_search_keywords_for_topic'): search_keywords = await research_agent.generate_search_keywords_for_topic(topic) else: # 后备方案:使用主题自带关键词 search_keywords = topic.get('keywords', []) if not search_keywords: search_keywords = [topic['english_title']] # 确保搜索关键词不是错误消息 if isinstance(search_keywords, list) and all(isinstance(kw, str) for kw in search_keywords): # 验证关键词是否合理 if any(len(kw) > 3 for kw in search_keywords): # 关键词看起来合理 pass else: # 关键词可能不合理,使用后备关键词 search_keywords = [topic['english_title']] + topic.get('keywords', []) else: # 关键词格式不正确,使用后备关键词 search_keywords = [topic['english_title']] + topic.get('keywords', []) # 发送搜索关键词 keywords_result = { 'status': 'search_keywords_generated', 'data': { 'language': base_keywords_data['language'], 'keywords': search_keywords, 'topic_index': i, 'topic': topic['english_title'] }, 'message': f'研究方向 {i+1} 的搜索关键词生成完成' } yield f"data: {json.dumps(keywords_result)}\n\n" await asyncio.sleep(0.1) # 记录日志,验证进度 logger.info(f"Generated search keywords for topic {i+1}: {search_keywords}") except Exception as e: logger.error(f"Error generating search keywords for topic {i+1}: {str(e)}", exc_info=True) yield f"data: {{\"status\": \"warning\", \"message\": \"生成方向 {i+1} 搜索关键词失败: {str(e)}\"}}\n\n" # 使用更可靠的后备关键词 search_keywords = [topic['english_title']] if 'keywords' in topic and isinstance(topic['keywords'], list) and len(topic['keywords']) > 0: search_keywords.extend(topic['keywords'][:2]) # 3.2 搜索论文 try: yield f"data: {json.dumps({'status': 'searching_papers', 'message': f'正在搜索方向 {i+1}/{len(research_topics)} 的相关论文...', 'current_direction': i})}\n\n" # 检查方法是否存在 if hasattr(research_agent, 'search_papers_with_keywords'): papers_data = await research_agent.search_papers_with_keywords(search_keywords) else: # 后备方案:使用现有方法 papers_data = await research_agent.search_papers_for_direction(topic['english_title']) logger.info(f"Found {len(papers_data)} papers for topic {i+1}") except Exception as e: logger.error(f"Error searching papers for topic {i+1}: {str(e)}", exc_info=True) yield f"data: {json.dumps({'status': 'warning', 'message': f'搜索方向 {i+1} 论文失败: {str(e)}'})}\n\n" # 空列表作为后备 papers_data = [] # 3.3 论文聚类 - 可选步骤 try: if papers_data and len(papers_data) >= 3: yield f"data: {json.dumps({'status': 'clustering_papers', 'message': f'正在对方向 {i+1} 的论文进行聚类分析...'})}\n\n" # 检查方法是否存在 if hasattr(research_agent, 'cluster_papers_by_keywords'): clustered_papers = await research_agent.cluster_papers_by_keywords( papers_data, search_keywords, topic ) else: # 后备方案:简单结构 clustered_papers = { "papers": papers_data, "clusters": [{ "id": 0, "name": topic["english_title"], "keywords": search_keywords[:3] if search_keywords else [] }] } else: # 论文太少,不聚类 clustered_papers = { "papers": papers_data, "clusters": [{ "id": 0, "name": topic["english_title"], "keywords": search_keywords[:3] if search_keywords else [] }] } except Exception as e: logger.error(f"Error clustering papers for topic {i+1}: {str(e)}", exc_info=True) # 后备方案 clustered_papers = { "papers": papers_data, "clusters": [{ "id": 0, "name": topic["english_title"], "keywords": search_keywords[:3] if search_keywords else [] }] } # 3.4 发送论文结果 papers_result = { 'status': 'papers_ready', 'data': { 'direction': topic['english_title'], 'original_direction': topic['title'], 'search_keywords': search_keywords, 'papers': clustered_papers['papers'], 'cluster_info': clustered_papers.get('clusters', []), 'direction_index': i }, 'message': f'方向 {i+1} 的论文搜索和聚类完成' } yield f"data: {json.dumps(papers_result)}\n\n" await asyncio.sleep(0.1) # 3.5 生成报告 - 如果有论文 if papers_data: try: yield f"data: {{\"status\": \"generating_report\", \"message\": \"正在为方向 {i+1} 生成研究报告...\", \"current_direction\": {i}}}\n\n" await asyncio.sleep(0.1) # 使用asyncio.wait_for限制执行时间 try: # 始终使用generate_enhanced_report函数 report_data = await asyncio.wait_for( research_agent.generate_enhanced_report( topic, clustered_papers, search_keywords, base_keywords_data["language"] ), timeout=300.0 ) # 发送报告结果 report_result = { 'status': 'report_ready', 'data': { 'direction': topic['english_title'], 'original_direction': topic['title'], 'search_keywords': search_keywords, 'report': report_data, 'direction_index': i }, 'message': f'方向 {i+1} 的研究报告生成完成' } yield f"data: {json.dumps(report_result)}\n\n" logger.info(f"Generated report for topic {i+1}") except asyncio.TimeoutError: # 特别处理超时情况 logger.warning(f"Report generation timed out for topic {i+1}") # 使用更直接的后备机制 fallback_report = { "english_content": f"# Report generation timed out\n\nWe apologize, but the report generation for this topic took too long and was automatically stopped. This topic has {len(papers_data)} related papers that you can review manually.", "translated_content": f"# 报告生成超时\n\n很抱歉,此主题的报告生成时间过长,已自动停止。该主题有{len(papers_data)}篇相关论文,您可以手动查看。" } fallback_result = { 'status': 'report_ready', 'data': { 'direction': topic['english_title'], 'original_direction': topic['title'], 'search_keywords': search_keywords, 'report': fallback_report, 'direction_index': i, 'is_fallback': True }, 'message': f'方向 {i+1} 的报告生成超时,提供简化版本' } yield f"data: {json.dumps(fallback_result)}\n\n" await asyncio.sleep(0.1) # 添加完成消息,确保前端知道此主题处理完毕 yield f"data: {{\"status\": \"direction_completed\", \"message\": \"方向 {i+1} 处理完成\", \"direction_index\": {i}}}\n\n" await asyncio.sleep(0.1) except Exception as e: logger.error(f"Error handling report generation for topic {i+1}: {str(e)}", exc_info=True) yield f"data: {{\"status\": \"warning\", \"message\": \"生成方向 {i+1} 研究报告处理失败: {str(e)}\"}}\n\n" await asyncio.sleep(0.1) else: logger.warning(f"No papers found for topic {i+1}, skipping report generation") except Exception as e: # 错误处理 logger.error(f"Error in enhanced streaming process: {str(e)}", exc_info=True) error_result = {'status': 'error', 'message': f'处理过程中发生错误: {str(e)}'} yield f"data: {json.dumps(error_result)}\n\n" # 返回流式响应 return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no' } )