""" API客户端工具 用于与外部API(如LLM API、文献检索API)进行交互 """ import httpx import json import os import time from typing import Dict, List, Any, Optional import logging from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type import hashlib from backend.config import DEEPSEEK_API_KEY, DEEPSEEK_API_URL, DEEPSEEK_MODEL logger = logging.getLogger(__name__) class LLMCache: """LLM响应缓存,减少重复请求""" def __init__(self, cache_size=200): self.cache = {} self.cache_size = cache_size self.hits = 0 self.misses = 0 def get(self, prompt, temperature): """获取缓存响应""" key = self._make_key(prompt, temperature) if key in self.cache: self.hits += 1 return self.cache[key] self.misses += 1 return None def set(self, prompt, temperature, response): """设置缓存响应""" key = self._make_key(prompt, temperature) self.cache[key] = response # 如果缓存太大,删除最旧条目 if len(self.cache) > self.cache_size: oldest_key = next(iter(self.cache)) del self.cache[oldest_key] def _make_key(self, prompt, temperature): """创建缓存键""" # 使用提示内容的MD5哈希作为键 prompt_hash = hashlib.md5(prompt.encode('utf-8')).hexdigest() return f"{prompt_hash}_{temperature}" def get_stats(self): """获取缓存统计信息""" total = self.hits + self.misses hit_rate = (self.hits / total) * 100 if total > 0 else 0 return { "hits": self.hits, "misses": self.misses, "total": total, "hit_rate": hit_rate, "size": len(self.cache) } class LLMClient: """大型语言模型API客户端""" def __init__( self, api_key: Optional[str] = None, model: Optional[str] = None, api_url: Optional[str] = None ): self.api_key = api_key or DEEPSEEK_API_KEY self.model = model or DEEPSEEK_MODEL self.api_url = api_url or DEEPSEEK_API_URL if not self.api_key: logger.warning("No LLM API key provided. API calls will fail.") self.cache = LLMCache() @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout)) ) async def generate_text( self, prompt: str, temperature: float = 0.3, max_tokens: int = 1000, timeout: float = 60.0 # 默认超时时间 ) -> str: """优化后的文本生成函数,支持自定义超时和重试""" # 检查缓存 cached_response = self.cache.get(prompt, temperature) if cached_response: logger.info("Using cached LLM response") return cached_response headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}" } payload = { "model": self.model, "messages": [{"role": "user", "content": prompt}], "temperature": temperature, "max_tokens": max_tokens } # 根据请求大小动态调整超时 adjusted_timeout = timeout if max_tokens > 1000: # 更大的请求需要更长的超时 adjusted_timeout = timeout * (max_tokens / 1000) * 1.5 try: async with httpx.AsyncClient(timeout=adjusted_timeout) as client: response = await client.post( f"{self.api_url}/chat/completions", headers=headers, json=payload ) response.raise_for_status() result = response.json() # 添加到缓存 content = result["choices"][0]["message"]["content"] self.cache.set(prompt, temperature, content) return content except httpx.TimeoutException as e: logger.error(f"Timeout calling LLM API ({adjusted_timeout}s): {str(e)}") raise Exception(f"LLM服务响应超时,请稍后重试") from e except Exception as e: logger.error(f"Error calling LLM API: {str(e)}") raise Exception(f"LLM服务调用失败: {str(e)}") from e async def generate_text_chunked( self, prompt: str, max_chunk_tokens: int = 800, temperature: float = 0.3 ) -> str: """ 针对大型文本生成任务的分块处理函数 将大型生成任务分解为多个小块,然后合并结果 """ # 计算预计的请求大小,对超大请求进行拆分 estimated_tokens = len(prompt.split()) * 1.5 # 粗略估计 if estimated_tokens > 2000: # 简化提示,移除不必要的说明性文本 prompt = self._simplify_large_prompt(prompt) # 对于报告生成这类大型任务,使用分块策略 if "structured research report" in prompt.lower() and estimated_tokens > 1000: # 1. 首先生成报告大纲 outline_prompt = f""" Based on the following research direction and papers, create a concise outline for a research report with 3-5 main sections. Research direction: {prompt.split('Research direction:', 1)[1].split('Key papers:', 1)[0].strip()} Only provide the outline with section titles, no additional text. """ outline = await self.generate_text(outline_prompt, temperature=0.3, max_tokens=400) # 2. 然后为每个部分生成内容 sections = [s.strip() for s in outline.split('\n') if s.strip() and not s.strip().startswith('#')] # 3. 逐个处理每个部分 full_report = "" for section in sections[:5]: # 限制最多5个部分 section_prompt = f""" Write the content for the "{section}" section of a research report about: {prompt.split('Research direction:', 1)[1].split('Key papers:', 1)[0].strip()} Based on these papers: {prompt.split('Key papers:', 1)[1].strip() if 'Key papers:' in prompt else ''} Write a detailed section (250-400 words) with relevant information, findings, and analysis. """ section_content = await self.generate_text( section_prompt, temperature=temperature, max_tokens=600 ) full_report += f"## {section}\n\n{section_content}\n\n" return full_report else: # 对于较小的任务,使用标准方法 return await self.generate_text(prompt, temperature=temperature, max_tokens=max_chunk_tokens) def _simplify_large_prompt(self, prompt: str) -> str: """简化大型提示,删除冗余指令""" # 删除重复的指令和多余的空白 lines = [line.strip() for line in prompt.split('\n')] lines = [line for line in lines if line] # 移除多余的模板语言 skip_phrases = [ "Please follow these guidelines:", "Make sure to:", "Remember to:", "Follow this structure:", "Your report should include:" ] simplified_lines = [] for line in lines: if not any(phrase in line for phrase in skip_phrases): simplified_lines.append(line) return "\n".join(simplified_lines) async def translate_text( self, text: str, source_language: str = "auto", target_language: str = "en" ) -> str: """ Translate text from source language to target language Args: text: Text to translate source_language: Source language code (or "auto" for auto-detection) target_language: Target language code Returns: Translated text """ prompt = f""" Translate the following text from {source_language} to {target_language}: {text} Provide only the translated text with no additional explanations or comments. """ return await self.generate_text(prompt, temperature=0.1) async def detect_language(self, text: str) -> str: """ Detect the language of the given text Args: text: Text to detect language for Returns: Language code (e.g., "en", "zh", "es") """ prompt = f""" Detect the language of the following text. Respond with just the ISO 639-1 language code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish). Text: {text} Language code: """ response = await self.generate_text(prompt, temperature=0.1) return response.strip().lower() async def extract_keywords(self, research_topic: str, original_language: str = "en") -> Dict[str, Any]: """ Extract keywords from research topic in any language Args: research_topic: User input research topic original_language: Original language code Returns: Dictionary with keywords in English and the original language """ # Detect language if not provided if original_language == "auto": original_language = await self.detect_language(research_topic) # Translate to English if not already in English english_topic = research_topic if original_language != "en": english_topic = await self.translate_text(research_topic, original_language, "en") # Extract English keywords prompt = f""" As a research assistant, extract 5-8 key search terms from the following research topic. These terms should be useful for academic literature search. Consider core concepts, methodology, application fields, etc. Return only a comma-separated list of keywords without numbering or other text. Research topic: {english_topic} Keywords: """ response = await self.generate_text(prompt, temperature=0.1) english_keywords = [kw.strip() for kw in response.split(',') if kw.strip()] # Translate keywords back to original language if needed original_keywords = english_keywords if original_language != "en": keywords_string = ", ".join(english_keywords) translated_keywords = await self.translate_text(keywords_string, "en", original_language) original_keywords = [kw.strip() for kw in translated_keywords.split(',') if kw.strip()] return { "english_keywords": english_keywords, "original_keywords": original_keywords, "language": original_language } async def generate_research_directions(self, keywords: List[str], original_language: str = "en") -> Dict[str, Any]: """ Generate multiple research directions based on keywords Args: keywords: List of English keywords original_language: Original language of user input Returns: Dictionary with research directions in English and original language """ keywords_string = ", ".join(keywords) prompt = f""" As a research advisor, generate 4-6 distinct research directions based on these keywords: {keywords_string} Each research direction should: 1. Be focused and specific enough for academic research 2. Represent a different aspect or approach to studying the topic 3. Be formulated as a clear research question or direction statement Provide each research direction on a new line without numbering. """ response = await self.generate_text(prompt, temperature=0.7, max_tokens=1500) # Parse research directions english_directions = [dir.strip() for dir in response.splitlines() if dir.strip()] # Translate research directions to original language if needed original_directions = english_directions if original_language != "en": directions_text = "\n".join(english_directions) translated_directions = await self.translate_text(directions_text, "en", original_language) original_directions = [dir.strip() for dir in translated_directions.splitlines() if dir.strip()] return { "english_directions": english_directions, "original_directions": original_directions, "language": original_language } async def batch_translate(self, texts: List[str], source_lang="en", target_lang="zh") -> List[str]: """批量翻译多个文本片段""" if not texts: return [] if source_lang == target_lang: return texts # 将多个文本合并,使用特殊标记分隔 separator = "\n===SECTION_BREAK===\n" combined_text = separator.join(texts) prompt = f""" Translate the following text sections from {source_lang} to {target_lang}. Each section is separated by "{separator}". Keep each section separate in your translation. {combined_text} """ try: translated = await self.generate_text(prompt, temperature=0.1) # 拆分翻译结果 sections = translated.split(separator) # 确保返回与输入相同数量的部分 while len(sections) < len(texts): sections.append("") return sections[:len(texts)] except Exception as e: logger.error(f"Batch translation error: {str(e)}") # 出错时返回原文 return texts @retry( stop=stop_after_attempt(3), # 最多重试3次 wait=wait_exponential(multiplier=1, min=2, max=10), # 指数退避等待 retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout)) ) async def chat_completion( self, messages: List[Dict[str, str]], temperature: float = 0.3, max_tokens: int = 1000, response_format: Optional[Dict[str, str]] = None ) -> Dict[str, Any]: """ 发送聊天完成请求到API """ headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}" } payload = { "model": self.model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens } if response_format: payload["response_format"] = response_format try: # 增加超时时间到120秒 logger.info(f"正在发送请求到 {self.api_url}/chat/completions") logger.info(f"使用模型: {self.model}") logger.info(f"API密钥前4位: {self.api_key[:4]}****") async with httpx.AsyncClient(timeout=120.0) as client: response = await client.post( f"{self.api_url}/chat/completions", headers=headers, json=payload ) # 记录完整的响应状态和正文 logger.info(f"API响应状态码: {response.status_code}") if response.status_code != 200: logger.error(f"API响应错误: {response.text}") response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}") if e.response.status_code == 404: raise Exception("API端点未找到,请检查API URL配置是否正确") elif e.response.status_code == 401: raise Exception("API密钥无效或未授权,请检查您的API密钥") else: raise Exception(f"API请求失败: HTTP {e.response.status_code}") except httpx.RequestError as e: logger.error(f"请求错误: {str(e)}") raise Exception(f"无法连接到LLM API服务器: {str(e)}") except Exception as e: logger.error(f"chat_completion错误: {str(e)}") raise Exception(f"LLM API调用失败: {str(e)}") class ArxivClient: """arXiv API客户端""" async def search_papers( self, query: str, max_results: int = 4 ) -> List[Dict[str, Any]]: """搜索arXiv文献""" try: import arxiv import asyncio # 简化查询以提高成功率 simplified_query = query if len(query) > 100: # 如果查询太长 # 只保留核心部分 if '"' in query: import re main_terms = re.findall(r'"([^"]*)"', query) if main_terms and len(main_terms) > 1: simplified_query = ' AND '.join([f'"{term}"' for term in main_terms[:2]]) elif main_terms: simplified_query = f'"{main_terms[0]}"' else: parts = query.split(' AND ') if len(parts) > 2: simplified_query = ' AND '.join(parts[:2]) logger.info(f"使用查询: {simplified_query}") search = arxiv.Search( query=simplified_query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance ) results = [] for paper in search.results(): results.append({ "id": paper.get_short_id(), "title": paper.title, "authors": [author.name for author in paper.authors], "summary": paper.summary, "published": paper.published.isoformat() if paper.published else None, "updated": paper.updated.isoformat() if paper.updated else None, "link": paper.pdf_url, "source": "arxiv" }) return results except Exception as e: logger.error(f"Error searching arXiv: {str(e)}") # 发生错误时返回空列表 return []