research.py

  1. """
  2. 研究流程管理模块
  3. 整合关键词提取、文献检索和聚类分析
  4. """
  5. import asyncio
  6. from typing import Dict, List, Any, Optional
  7. import logging
  8. import time
  9. import json
  10. from pathlib import Path
  11. from backend.utils.api_client import LLMClient, ArxivClient
  12. from backend.utils.keywords import extract_keywords, expand_search_queries
  13. from backend.config import (
  14. MAX_SEARCH_RESULTS,
  15. CACHE_DIR,
  16. ENABLE_CACHE
  17. )
  18. logger = logging.getLogger(__name__)


class ResearchAgent:
    """Research agent that manages the entire research workflow."""

    def __init__(self):
        self.llm_client = LLMClient()
        self.arxiv_client = ArxivClient()

    async def process_research_intent(
        self,
        research_intent: str,
        max_results: int = MAX_SEARCH_RESULTS
    ) -> Dict[str, Any]:
  29. """
  30. 处理研究意图,执行完整的研究流程
  31. Args:
  32. research_intent: 用户输入的研究意图
  33. max_results: 最大检索结果数量
  34. Returns:
  35. 包含研究结果的字典
  36. """
  37. start_time = time.time()
  38. # 初始化结果字典
  39. result = {
  40. "research_intent": research_intent,
  41. "timestamp": time.time(),
  42. "keywords": [],
  43. "papers": [],
  44. "clusters": [],
  45. "status": "processing"
  46. }

        try:
            # 1. Extract keywords
            logger.info(f"Extracting keywords from: {research_intent}")
            keywords = await extract_keywords(research_intent, self.llm_client)
            result["keywords"] = keywords
            logger.info(f"Extracted keywords: {keywords}")

            # 2. Expand search queries
            search_queries = await expand_search_queries(keywords, self.llm_client)
            result["search_queries"] = search_queries
            logger.info(f"Generated search queries: {search_queries}")

            # 3. Search for papers, splitting the result budget across queries
            all_papers = []
            for query in search_queries:
                logger.info(f"Searching papers with query: {query}")
                papers = await self.arxiv_client.search_papers(
                    query=query,
                    # Request at least one result per query even when there
                    # are more queries than max_results
                    max_results=max(1, max_results // len(search_queries))
                )
                all_papers.extend(papers)

            # Deduplicate papers by id, preserving first-seen order
            unique_papers = []
            paper_ids = set()
            for paper in all_papers:
                if paper["id"] not in paper_ids:
                    unique_papers.append(paper)
                    paper_ids.add(paper["id"])
            result["papers"] = unique_papers
            logger.info(f"Found {len(unique_papers)} unique papers")

            # 4. Clustering analysis (to be implemented in clustering.py)
            # TODO: implement paper clustering
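            # One possible shape for this step, purely illustrative: the
            # cluster_papers helper and the cluster dict layout below are
            # assumptions, not the actual clustering.py interface.
            #
            #   from backend.utils.clustering import cluster_papers
            #   result["clusters"] = await cluster_papers(
            #       unique_papers, self.llm_client
            #   )  # e.g. [{"label": str, "paper_ids": [str, ...]}, ...]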

            result["status"] = "completed"
            result["processing_time"] = time.time() - start_time

            # Cache the result
            if ENABLE_CACHE:
                self._cache_result(result)

            return result
        except Exception as e:
            logger.error(f"Error in research process: {str(e)}", exc_info=True)
            result["status"] = "error"
            result["error"] = str(e)
            return result

    def _cache_result(self, result: Dict[str, Any]) -> None:
        """Cache a research result to disk as JSON."""
        try:
            cache_file = Path(CACHE_DIR) / f"research_{int(time.time())}.json"
            # Make sure the cache directory exists before writing
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
            logger.info(f"Cached result to {cache_file}")
        except Exception as e:
            logger.error(f"Failed to cache result: {str(e)}")
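

# --- Usage sketch (illustrative, not part of the original module) ---
# Minimal example of driving the agent from a script. It assumes the backend
# package is importable and that MAX_SEARCH_RESULTS, CACHE_DIR and ENABLE_CACHE
# are configured; the research intent string below is just an example.
if __name__ == "__main__":
    async def _demo() -> None:
        agent = ResearchAgent()
        outcome = await agent.process_research_intent(
            "evaluation methods for machine translation", max_results=20
        )
        print(f"status={outcome['status']}, papers={len(outcome['papers'])}")

    asyncio.run(_demo())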