research.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. """
  2. 研究相关API路由
  3. """
  4. from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
  5. from pydantic import BaseModel, Field
  6. from typing import List, Dict, Any, Optional
  7. import logging
  8. from backend.core.research import ResearchAgent
  9. from backend.core.clustering import PaperClusterer
  10. from backend.core.report import ReportGenerator
  11. from backend.config import MAX_SEARCH_RESULTS
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Router for the research endpoints: every route registered on it is
  # mounted under the /api/research prefix and carries the "research" tag.
  router = APIRouter(
      prefix="/api/research",
      tags=["research"],
  )
  14. # 数据模型
  class ResearchRequest(BaseModel):
      # Request body for the full research pipeline endpoint.

      # Required: the user's research intent as text.
      research_intent: str = Field(
          default=...,
          description="用户的研究意图",
      )
      # Maximum number of search results; falls back to MAX_SEARCH_RESULTS.
      max_results: int = Field(
          default=MAX_SEARCH_RESULTS,
          description="最大检索结果数量",
      )
  class KeywordsRequest(BaseModel):
      # Request body for keyword extraction.

      # Required: the user's research intent as text.
      research_intent: str = Field(
          default=...,
          description="用户的研究意图",
      )
  class PaperSearchRequest(BaseModel):
      # Request body for keyword-based paper search.

      # Required: the search keywords, one query is issued per keyword.
      keywords: List[str] = Field(
          default=...,
          description="检索关键词",
      )
      # Maximum number of search results; falls back to MAX_SEARCH_RESULTS.
      max_results: int = Field(
          default=MAX_SEARCH_RESULTS,
          description="最大检索结果数量",
      )
  class ClusterRequest(BaseModel):
      # Request body for paper clustering.

      # Required: the papers to cluster, each given as a free-form dict.
      papers: List[Dict[str, Any]] = Field(
          default=...,
          description="要聚类的论文",
      )
      # Number of clusters; defaults to 0, which per the field description
      # means the count is determined automatically.
      num_clusters: int = Field(
          default=0,
          description="聚类数量,0表示自动确定",
      )
  class ReportRequest(BaseModel):
      # Request body for report generation.

      # Required: the research intent the report is written for.
      research_intent: str = Field(
          default=...,
          description="研究意图",
      )
      # Required: the keywords associated with the research intent.
      keywords: List[str] = Field(
          default=...,
          description="关键词",
      )
      # Required: the retrieved papers, each given as a free-form dict.
      papers: List[Dict[str, Any]] = Field(
          default=...,
          description="检索到的论文",
      )
      # Optional clustering result; None when no clustering is supplied.
      clusters: Optional[Dict[str, Any]] = Field(
          default=None,
          description="聚类结果",
      )
  31. # 全局实例
  # Shared service objects: built once when this module is imported and
  # used by the request handlers defined in this module.
  research_agent: ResearchAgent = ResearchAgent()
  paper_clusterer: PaperClusterer = PaperClusterer()
  report_generator: ReportGenerator = ReportGenerator()
  35. # 路由定义
  @router.post("/process")
  async def process_research(request: ResearchRequest):
      """Run the complete research workflow for a research intent.

      Delegates to ``research_agent.process_research_intent`` and returns its
      result unchanged.

      Args:
          request: Body carrying ``research_intent`` and ``max_results``.

      Returns:
          Whatever ``process_research_intent`` returns.

      Raises:
          HTTPException: Status 500, with the error text as ``detail``, when
              any exception is raised while processing.
      """
      try:
          # %-style arguments defer string formatting to the logging framework.
          logger.info("Processing research intent: %s", request.research_intent)
          return await research_agent.process_research_intent(
              research_intent=request.research_intent,
              max_results=request.max_results,
          )
      except Exception as exc:
          # logger.exception logs at ERROR level with the traceback attached.
          logger.exception("Error in research process: %s", exc)
          # Chain the cause so the original error stays attached to the
          # HTTPException that is raised in its place.
          raise HTTPException(status_code=500, detail=str(exc)) from exc
  @router.post("/extract-keywords")
  async def extract_keywords(request: KeywordsRequest):
      """Extract keywords from a research intent.

      Delegates to ``research_agent.llm_client.extract_keywords``, passing the
      research intent as ``research_topic``.

      Args:
          request: Body carrying ``research_intent``.

      Returns:
          A dict ``{"keywords": <extraction result>}``.

      Raises:
          HTTPException: Status 500, with the error text as ``detail``, when
              any exception is raised during extraction.
      """
      try:
          # %-style arguments defer string formatting to the logging framework.
          logger.info("Extracting keywords from: %s", request.research_intent)
          extracted = await research_agent.llm_client.extract_keywords(
              research_topic=request.research_intent,
          )
          return {"keywords": extracted}
      except Exception as exc:
          # logger.exception logs at ERROR level with the traceback attached.
          logger.exception("Error extracting keywords: %s", exc)
          # Chain the cause so the original error stays attached to the
          # HTTPException that is raised in its place.
          raise HTTPException(status_code=500, detail=str(exc)) from exc
  @router.post("/search-papers")
  async def search_papers(request: PaperSearchRequest):
      """Search papers for each keyword and return the de-duplicated results.

      One query is issued per keyword, in order, through
      ``research_agent.arxiv_client.search_papers``. Papers are de-duplicated
      on their ``"id"`` key; the first occurrence of each id is kept and the
      original ordering is preserved.

      Args:
          request: Body carrying ``keywords`` and ``max_results``.

      Returns:
          A dict ``{"papers": [...], "count": <number of unique papers>}``.

      Raises:
          HTTPException: Status 500, with the error text as ``detail``, when
              any exception is raised (including a paper without an ``"id"``).
      """
      try:
          # %-style arguments defer string formatting to the logging framework.
          logger.info("Searching papers with keywords: %s", request.keywords)

          keywords = request.keywords
          if not keywords:
              # No keyword means no query; returning here also keeps the
              # division below from ever seeing a zero divisor.
              return {"papers": [], "count": 0}

          # Loop-invariant: split the result budget evenly across keywords,
          # but always request at least 3 papers per keyword.
          per_keyword_limit = max(3, request.max_results // len(keywords))

          # Dicts preserve insertion order, so this keeps the first paper
          # seen for every id in the order the results arrived.
          unique_by_id = {}
          for keyword in keywords:
              results = await research_agent.arxiv_client.search_papers(
                  query=keyword,
                  max_results=per_keyword_limit,
              )
              for paper in results:
                  unique_by_id.setdefault(paper["id"], paper)

          unique_papers = list(unique_by_id.values())
          return {"papers": unique_papers, "count": len(unique_papers)}
      except Exception as exc:
          # logger.exception logs at ERROR level with the traceback attached.
          logger.exception("Error searching papers: %s", exc)
          # Chain the cause so the original error stays attached to the
          # HTTPException that is raised in its place.
          raise HTTPException(status_code=500, detail=str(exc)) from exc
  @router.post("/cluster-papers")
  async def cluster_papers(request: ClusterRequest):
      """Cluster the supplied papers.

      Delegates to ``paper_clusterer.cluster_papers`` and returns its result
      unchanged.

      Args:
          request: Body carrying ``papers`` and ``num_clusters``.

      Returns:
          Whatever ``paper_clusterer.cluster_papers`` returns.

      Raises:
          HTTPException: Status 500, with the error text as ``detail``, when
              any exception is raised during clustering.
      """
      try:
          # %-style arguments defer string formatting to the logging framework.
          logger.info("Clustering %d papers", len(request.papers))
          # NOTE(review): this call is synchronous (not awaited), so it runs
          # inside the async handler until it returns. If clustering is slow,
          # consider offloading it to a worker thread -- confirm first that
          # the shared PaperClusterer instance is safe to use concurrently.
          return paper_clusterer.cluster_papers(
              papers=request.papers,
              num_clusters=request.num_clusters,
          )
      except Exception as exc:
          # logger.exception logs at ERROR level with the traceback attached.
          logger.exception("Error clustering papers: %s", exc)
          # Chain the cause so the original error stays attached to the
          # HTTPException that is raised in its place.
          raise HTTPException(status_code=500, detail=str(exc)) from exc
  @router.post("/generate-report")
  async def generate_report(request: ReportRequest):
      """Generate a research report.

      Delegates to ``report_generator.generate_report`` and returns its result
      unchanged.

      Args:
          request: Body carrying ``research_intent``, ``keywords``, ``papers``
              and the optional ``clusters``.

      Returns:
          Whatever ``report_generator.generate_report`` returns.

      Raises:
          HTTPException: Status 500, with the error text as ``detail``, when
              any exception is raised during report generation.
      """
      try:
          # %-style arguments defer string formatting to the logging framework.
          logger.info("Generating report for: %s", request.research_intent)
          return await report_generator.generate_report(
              research_intent=request.research_intent,
              keywords=request.keywords,
              papers=request.papers,
              clusters=request.clusters,
          )
      except Exception as exc:
          # logger.exception logs at ERROR level with the traceback attached.
          logger.exception("Error generating report: %s", exc)
          # Chain the cause so the original error stays attached to the
          # HTTPException that is raised in its place.
          raise HTTPException(status_code=500, detail=str(exc)) from exc