""" 研究流程管理模块 整合关键词提取、文献检索和聚类分析 """ import asyncio from typing import Dict, List, Any, Optional import logging import time import json from pathlib import Path from backend.utils.api_client import LLMClient, ArxivClient from backend.utils.keywords import extract_keywords, expand_search_queries from backend.config import ( MAX_SEARCH_RESULTS, CACHE_DIR, ENABLE_CACHE ) logger = logging.getLogger(__name__) class ResearchAgent: """研究智能体,管理整个研究流程""" def __init__(self): self.llm_client = LLMClient() self.arxiv_client = ArxivClient() async def process_research_intent( self, research_intent: str, max_results: int = MAX_SEARCH_RESULTS ) -> Dict[str, Any]: """ 处理研究意图,执行完整的研究流程 Args: research_intent: 用户输入的研究意图 max_results: 最大检索结果数量 Returns: 包含研究结果的字典 """ start_time = time.time() # 初始化结果字典 result = { "research_intent": research_intent, "timestamp": time.time(), "keywords": [], "papers": [], "clusters": [], "status": "processing" } try: # 1. 提取关键词 logger.info(f"Extracting keywords from: {research_intent}") keywords = await extract_keywords(research_intent, self.llm_client) result["keywords"] = keywords logger.info(f"Extracted keywords: {keywords}") # 2. 扩展搜索查询 search_queries = await expand_search_queries(keywords, self.llm_client) result["search_queries"] = search_queries logger.info(f"Generated search queries: {search_queries}") # 3. 检索文献 all_papers = [] for query in search_queries: logger.info(f"Searching papers with query: {query}") papers = await self.arxiv_client.search_papers( query=query, max_results=max_results // len(search_queries) ) all_papers.extend(papers) # 去重 unique_papers = [] paper_ids = set() for paper in all_papers: if paper["id"] not in paper_ids: unique_papers.append(paper) paper_ids.add(paper["id"]) result["papers"] = unique_papers logger.info(f"Found {len(unique_papers)} unique papers") # 4. 聚类分析(这部分将在clustering.py中实现) # TODO: 实现文献聚类 result["status"] = "completed" result["processing_time"] = time.time() - start_time # 缓存结果 if ENABLE_CACHE: self._cache_result(result) return result except Exception as e: logger.error(f"Error in research process: {str(e)}", exc_info=True) result["status"] = "error" result["error"] = str(e) return result def _cache_result(self, result: Dict[str, Any]) -> None: """缓存研究结果""" try: cache_file = Path(CACHE_DIR) / f"research_{int(time.time())}.json" with open(cache_file, "w", encoding="utf-8") as f: json.dump(result, f, ensure_ascii=False, indent=2) logger.info(f"Cached result to {cache_file}") except Exception as e: logger.error(f"Failed to cache result: {str(e)}")