- """
- 研究流程管理模块
- 整合关键词提取、文献检索和聚类分析
- """
import asyncio
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict

from backend.utils.api_client import LLMClient, ArxivClient
from backend.utils.keywords import extract_keywords, expand_search_queries
from backend.config import (
    MAX_SEARCH_RESULTS,
    CACHE_DIR,
    ENABLE_CACHE,
)

logger = logging.getLogger(__name__)


class ResearchAgent:
    """Research agent that manages the end-to-end research pipeline."""

    def __init__(self):
        self.llm_client = LLMClient()
        self.arxiv_client = ArxivClient()

    async def process_research_intent(
        self,
        research_intent: str,
        max_results: int = MAX_SEARCH_RESULTS,
    ) -> Dict[str, Any]:
        """
        Process a research intent by running the full research pipeline.

        Args:
            research_intent: The research intent entered by the user.
            max_results: Maximum number of search results to retrieve.

        Returns:
            A dictionary containing the research results.
        """
        start_time = time.time()

        # Initialize the result dictionary.
        result = {
            "research_intent": research_intent,
            "timestamp": time.time(),
            "keywords": [],
            "papers": [],
            "clusters": [],
            "status": "processing",
        }

        try:
            # 1. Extract keywords.
            logger.info(f"Extracting keywords from: {research_intent}")
            keywords = await extract_keywords(research_intent, self.llm_client)
            result["keywords"] = keywords
            logger.info(f"Extracted keywords: {keywords}")

            # 2. Expand the search queries.
            search_queries = await expand_search_queries(keywords, self.llm_client)
            result["search_queries"] = search_queries
            logger.info(f"Generated search queries: {search_queries}")

            # 3. Retrieve papers, splitting the result budget across queries.
            # Guard against an empty query list and a zero per-query budget.
            per_query = max(1, max_results // max(1, len(search_queries)))
            all_papers = []
            for query in search_queries:
                logger.info(f"Searching papers with query: {query}")
                papers = await self.arxiv_client.search_papers(
                    query=query,
                    max_results=per_query,
                )
                all_papers.extend(papers)

            # Deduplicate by paper ID, keeping the first occurrence.
            unique_papers = []
            paper_ids = set()
            for paper in all_papers:
                if paper["id"] not in paper_ids:
                    unique_papers.append(paper)
                    paper_ids.add(paper["id"])

            result["papers"] = unique_papers
            logger.info(f"Found {len(unique_papers)} unique papers")

            # 4. Cluster analysis (to be implemented in clustering.py).
            # TODO: implement paper clustering.
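            # Hypothetical shape of that step; the cluster_papers name and
            # signature are assumptions, not an existing API:
            # result["clusters"] = await cluster_papers(
            #     unique_papers, llm_client=self.llm_client
            # )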

            result["status"] = "completed"
            result["processing_time"] = time.time() - start_time

            # Cache the result.
            if ENABLE_CACHE:
                self._cache_result(result)

            return result

        except Exception as e:
            logger.error(f"Error in research process: {str(e)}", exc_info=True)
            result["status"] = "error"
            result["error"] = str(e)
            return result

    def _cache_result(self, result: Dict[str, Any]) -> None:
        """Cache a research result to disk as JSON."""
        try:
            cache_dir = Path(CACHE_DIR)
            # Ensure the cache directory exists before writing.
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_file = cache_dir / f"research_{int(time.time())}.json"
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
            logger.info(f"Cached result to {cache_file}")
        except Exception as e:
            logger.error(f"Failed to cache result: {str(e)}")