- """
- API客户端工具
- 用于与外部API(如LLM API、文献检索API)进行交互
- """
- import httpx
- from typing import Dict, List, Any, Optional
- import logging
- from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
- import hashlib
- from backend.config import DEEPSEEK_API_KEY, DEEPSEEK_API_URL, DEEPSEEK_MODEL
- logger = logging.getLogger(__name__)
- class LLMCache:
- """LLM响应缓存,减少重复请求"""
-
- def __init__(self, cache_size=200):
- self.cache = {}
- self.cache_size = cache_size
- self.hits = 0
- self.misses = 0
-
- def get(self, prompt, temperature):
- """获取缓存响应"""
- key = self._make_key(prompt, temperature)
- if key in self.cache:
- self.hits += 1
- return self.cache[key]
- self.misses += 1
- return None
-
- def set(self, prompt, temperature, response):
- """设置缓存响应"""
- key = self._make_key(prompt, temperature)
- self.cache[key] = response
-
- # If the cache is too large, evict the oldest entry (dicts preserve insertion order)
- if len(self.cache) > self.cache_size:
- oldest_key = next(iter(self.cache))
- del self.cache[oldest_key]
-
- def _make_key(self, prompt, temperature):
- """创建缓存键"""
- # 使用提示内容的MD5哈希作为键
- prompt_hash = hashlib.md5(prompt.encode('utf-8')).hexdigest()
- return f"{prompt_hash}_{temperature}"
-
- def get_stats(self):
- """获取缓存统计信息"""
- total = self.hits + self.misses
- hit_rate = (self.hits / total) * 100 if total > 0 else 0
- return {
- "hits": self.hits,
- "misses": self.misses,
- "total": total,
- "hit_rate": hit_rate,
- "size": len(self.cache)
- }
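- # A minimal usage sketch for LLMCache (illustrative only; not executed at import time).
- # The key is the MD5 of the prompt plus the temperature, so an identical call is a hit:
- #
- #     cache = LLMCache(cache_size=2)
- #     cache.set("hello", 0.3, "hi there")
- #     cache.get("hello", 0.3)   # -> "hi there" (hit)
- #     cache.get("hello", 0.7)   # -> None (different temperature, miss)
- #     cache.get_stats()         # -> {"hits": 1, "misses": 1, "total": 2, "hit_rate": 50.0, "size": 1}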
- class LLMClient:
- """大型语言模型API客户端"""
-
- def __init__(
- self,
- api_key: Optional[str] = None,
- model: Optional[str] = None,
- api_url: Optional[str] = None
- ):
- self.api_key = api_key or DEEPSEEK_API_KEY
- self.model = model or DEEPSEEK_MODEL
- self.api_url = api_url or DEEPSEEK_API_URL
-
- if not self.api_key:
- logger.warning("No LLM API key provided. API calls will fail.")
-
- self.cache = LLMCache()
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=2, max=10),
- retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout))
- )
- async def generate_text(
- self,
- prompt: str,
- temperature: float = 0.3,
- max_tokens: int = 1000,
- timeout: float = 60.0 # default request timeout in seconds
- ) -> str:
- """优化后的文本生成函数,支持自定义超时和重试"""
- # 检查缓存
- cached_response = self.cache.get(prompt, temperature)
- if cached_response:
- logger.info("Using cached LLM response")
- return cached_response
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}"
- }
-
- payload = {
- "model": self.model,
- "messages": [{"role": "user", "content": prompt}],
- "temperature": temperature,
- "max_tokens": max_tokens
- }
-
- # Scale the timeout with the requested output size
- adjusted_timeout = timeout
- if max_tokens > 1000:
- # Larger generations need a longer timeout
- adjusted_timeout = timeout * (max_tokens / 1000) * 1.5
-
- try:
- async with httpx.AsyncClient(timeout=adjusted_timeout) as client:
- response = await client.post(
- f"{self.api_url}/chat/completions",
- headers=headers,
- json=payload
- )
- response.raise_for_status()
- result = response.json()
-
- # Store the result in the cache
- content = result["choices"][0]["message"]["content"]
- self.cache.set(prompt, temperature, content)
- return content
- except httpx.TimeoutException as e:
- logger.error(f"Timeout calling LLM API ({adjusted_timeout}s): {str(e)}")
- raise Exception(f"LLM服务响应超时,请稍后重试") from e
- except Exception as e:
- logger.error(f"Error calling LLM API: {str(e)}")
- raise Exception(f"LLM服务调用失败: {str(e)}") from e
-
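- # A minimal usage sketch for generate_text (assumes a valid DEEPSEEK_API_KEY and an
- # OpenAI-compatible /chat/completions endpoint; run inside an asyncio event loop):
- #
- #     client = LLMClient()
- #     summary = await client.generate_text(
- #         "Summarize the key ideas of retrieval-augmented generation in 3 sentences.",
- #         temperature=0.3,
- #         max_tokens=300,
- #         timeout=30.0,
- #     )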
- async def generate_text_chunked(
- self,
- prompt: str,
- max_chunk_tokens: int = 800,
- temperature: float = 0.3
- ) -> str:
- """
- 针对大型文本生成任务的分块处理函数
- 将大型生成任务分解为多个小块,然后合并结果
- """
- # 计算预计的请求大小,对超大请求进行拆分
- estimated_tokens = len(prompt.split()) * 1.5 # 粗略估计
-
- if estimated_tokens > 2000:
- # Simplify the prompt by removing unnecessary boilerplate instructions
- prompt = self._simplify_large_prompt(prompt)
-
- # For large tasks such as report generation, use a chunked strategy
- if "structured research report" in prompt.lower() and estimated_tokens > 1000:
- # Extract the research direction and paper list once, guarding against missing markers
- direction_text = (prompt.split('Research direction:', 1)[1].split('Key papers:', 1)[0].strip()
- if 'Research direction:' in prompt else prompt)
- papers_text = prompt.split('Key papers:', 1)[1].strip() if 'Key papers:' in prompt else ''
- 
- # 1. First generate an outline for the report
- outline_prompt = f"""
- Based on the following research direction and papers, create a concise outline
- for a research report with 3-5 main sections.
- 
- Research direction: {direction_text}
- 
- Only provide the outline with section titles, no additional text.
- """
- 
- outline = await self.generate_text(outline_prompt, temperature=0.3, max_tokens=400)
- 
- # 2. Then generate content for each section
- sections = [s.strip() for s in outline.split('\n') if s.strip() and not s.strip().startswith('#')]
- 
- # 3. Process each section in turn
- full_report = ""
- for section in sections[:5]: # limit to at most 5 sections
- section_prompt = f"""
- Write the content for the "{section}" section of a research report about:
- 
- {direction_text}
- 
- Based on these papers:
- {papers_text}
- 
- Write a detailed section (250-400 words) with relevant information, findings, and analysis.
- """
-
- section_content = await self.generate_text(
- section_prompt,
- temperature=temperature,
- max_tokens=600
- )
-
- full_report += f"## {section}\n\n{section_content}\n\n"
-
- return full_report
-
- else:
- # For smaller tasks, use the standard single-call path
- return await self.generate_text(prompt, temperature=temperature, max_tokens=max_chunk_tokens)
-
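- # A minimal usage sketch for generate_text_chunked (the outline-then-sections path is
- # only taken for long "structured research report" prompts that embed the
- # "Research direction:" / "Key papers:" markers; short prompts fall back to one call):
- #
- #     report_prompt = (
- #         "Write a structured research report.\n"
- #         "Research direction: efficient fine-tuning of large language models\n"
- #         "Key papers: LoRA (2021); Adapters (2019); Prefix-Tuning (2021)"
- #     )
- #     report = await client.generate_text_chunked(report_prompt, temperature=0.3)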
- def _simplify_large_prompt(self, prompt: str) -> str:
- """简化大型提示,删除冗余指令"""
- # 删除重复的指令和多余的空白
- lines = [line.strip() for line in prompt.split('\n')]
- lines = [line for line in lines if line]
-
- # Drop boilerplate template phrases
- skip_phrases = [
- "Please follow these guidelines:",
- "Make sure to:",
- "Remember to:",
- "Follow this structure:",
- "Your report should include:"
- ]
-
- simplified_lines = []
- for line in lines:
- if not any(phrase in line for phrase in skip_phrases):
- simplified_lines.append(line)
-
- return "\n".join(simplified_lines)
- async def translate_text(
- self,
- text: str,
- source_language: str = "auto",
- target_language: str = "en"
- ) -> str:
- """
- Translate text from source language to target language
-
- Args:
- text: Text to translate
- source_language: Source language code (or "auto" for auto-detection)
- target_language: Target language code
-
- Returns:
- Translated text
- """
- prompt = f"""
- Translate the following text from {source_language} to {target_language}:
-
- {text}
-
- Provide only the translated text with no additional explanations or comments.
- """
-
- return await self.generate_text(prompt, temperature=0.1)
- async def detect_language(self, text: str) -> str:
- """
- Detect the language of the given text
-
- Args:
- text: Text to detect language for
-
- Returns:
- Language code (e.g., "en", "zh", "es")
- """
- prompt = f"""
- Detect the language of the following text.
- Respond with just the ISO 639-1 language code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
-
- Text: {text}
-
- Language code:
- """
-
- response = await self.generate_text(prompt, temperature=0.1)
- return response.strip().lower()
- async def extract_keywords(self, research_topic: str, original_language: str = "en") -> Dict[str, Any]:
- """
- Extract keywords from research topic in any language
-
- Args:
- research_topic: User input research topic
- original_language: Original language code
-
- Returns:
- Dictionary with keywords in English and the original language
- """
- # Detect language if not provided
- if original_language == "auto":
- original_language = await self.detect_language(research_topic)
-
- # Translate to English if not already in English
- english_topic = research_topic
- if original_language != "en":
- english_topic = await self.translate_text(research_topic, original_language, "en")
-
- # Extract English keywords
- prompt = f"""
- As a research assistant, extract 5-8 key search terms from the following research topic.
- These terms should be useful for academic literature search. Consider core concepts, methodology,
- application fields, etc. Return only a comma-separated list of keywords without numbering or other text.
- Research topic: {english_topic}
-
- Keywords:
- """
-
- response = await self.generate_text(prompt, temperature=0.1)
- english_keywords = [kw.strip() for kw in response.split(',') if kw.strip()]
-
- # Translate keywords back to original language if needed
- original_keywords = english_keywords
- if original_language != "en":
- keywords_string = ", ".join(english_keywords)
- translated_keywords = await self.translate_text(keywords_string, "en", original_language)
- original_keywords = [kw.strip() for kw in translated_keywords.split(',') if kw.strip()]
-
- return {
- "english_keywords": english_keywords,
- "original_keywords": original_keywords,
- "language": original_language
- }
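- # A minimal usage sketch for extract_keywords (assumes the LLM follows the
- # comma-separated-output instruction; non-English topics are translated first):
- #
- #     result = await client.extract_keywords("深度学习在医学影像中的应用", original_language="auto")
- #     # result -> {"english_keywords": [...], "original_keywords": [...], "language": "zh"}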
- async def generate_research_directions(self, keywords: List[str], original_language: str = "en") -> Dict[str, Any]:
- """
- Generate multiple research directions based on keywords
-
- Args:
- keywords: List of English keywords
- original_language: Original language of user input
-
- Returns:
- Dictionary with research directions in English and original language
- """
- keywords_string = ", ".join(keywords)
-
- prompt = f"""
- As a research advisor, generate 4-6 distinct research directions based on these keywords: {keywords_string}
-
- Each research direction should:
- 1. Be focused and specific enough for academic research
- 2. Represent a different aspect or approach to studying the topic
- 3. Be formulated as a clear research question or direction statement
-
- Provide each research direction on a new line without numbering.
- """
-
- response = await self.generate_text(prompt, temperature=0.7, max_tokens=1500)
-
- # Parse research directions
- english_directions = [line.strip() for line in response.splitlines() if line.strip()]
-
- # Translate research directions to original language if needed
- original_directions = english_directions
- if original_language != "en":
- directions_text = "\n".join(english_directions)
- translated_directions = await self.translate_text(directions_text, "en", original_language)
- original_directions = [line.strip() for line in translated_directions.splitlines() if line.strip()]
-
- return {
- "english_directions": english_directions,
- "original_directions": original_directions,
- "language": original_language
- }
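- # A minimal usage sketch for generate_research_directions (keywords are the English
- # terms returned by extract_keywords; directions are translated back if needed):
- #
- #     dirs = await client.generate_research_directions(
- #         ["medical imaging", "deep learning", "segmentation"], original_language="zh"
- #     )
- #     # dirs -> {"english_directions": [...], "original_directions": [...], "language": "zh"}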
- async def batch_translate(self, texts: List[str], source_lang="en", target_lang="zh") -> List[str]:
- """批量翻译多个文本片段"""
- if not texts:
- return []
-
- if source_lang == target_lang:
- return texts
-
- # Join the texts with a sentinel separator so one request covers all of them
- separator = "\n===SECTION_BREAK===\n"
- combined_text = separator.join(texts)
-
- prompt = f"""
- Translate the following text sections from {source_lang} to {target_lang}.
- Each section is separated by "{separator}".
- Keep each section separate in your translation.
-
- {combined_text}
- """
-
- try:
- translated = await self.generate_text(prompt, temperature=0.1)
-
- # Split the translation back into sections
- sections = translated.split(separator)
-
- # Make sure we return the same number of sections as inputs
- while len(sections) < len(texts):
- sections.append("")
-
- return sections[:len(texts)]
- except Exception as e:
- logger.error(f"Batch translation error: {str(e)}")
- # On error, fall back to the original texts
- return texts
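- # A minimal usage sketch for batch_translate (relies on the model preserving the
- # ===SECTION_BREAK=== sentinel; on any failure the originals are returned unchanged):
- #
- #     zh_sections = await client.batch_translate(
- #         ["Introduction", "Methods and materials"], source_lang="en", target_lang="zh"
- #     )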
- @retry(
- stop=stop_after_attempt(3), # retry at most 3 times
- wait=wait_exponential(multiplier=1, min=2, max=10), # exponential backoff between attempts
- retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout))
- )
- async def chat_completion(
- self,
- messages: List[Dict[str, str]],
- temperature: float = 0.3,
- max_tokens: int = 1000,
- response_format: Optional[Dict[str, str]] = None
- ) -> Dict[str, Any]:
- """
- Send a chat-completion request to the API and return the raw JSON response
- """
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}"
- }
-
- payload = {
- "model": self.model,
- "messages": messages,
- "temperature": temperature,
- "max_tokens": max_tokens
- }
-
- if response_format:
- payload["response_format"] = response_format
-
- try:
- # Use a longer, 120-second timeout for chat completions
- logger.info(f"Sending request to {self.api_url}/chat/completions")
- logger.info(f"Using model: {self.model}")
- logger.info(f"First 4 characters of API key: {self.api_key[:4]}****")
-
- async with httpx.AsyncClient(timeout=120.0) as client:
- response = await client.post(
- f"{self.api_url}/chat/completions",
- headers=headers,
- json=payload
- )
-
- # Log the response status; the body is only logged below on errors
- logger.info(f"API response status code: {response.status_code}")
-
- if response.status_code != 200:
- logger.error(f"API响应错误: {response.text}")
-
- response.raise_for_status()
- return response.json()
- except httpx.HTTPStatusError as e:
- logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}")
- if e.response.status_code == 404:
- raise Exception("API端点未找到,请检查API URL配置是否正确")
- elif e.response.status_code == 401:
- raise Exception("API密钥无效或未授权,请检查您的API密钥")
- else:
- raise Exception(f"API请求失败: HTTP {e.response.status_code}")
- except httpx.RequestError as e:
- logger.error(f"请求错误: {str(e)}")
- raise Exception(f"无法连接到LLM API服务器: {str(e)}")
- except Exception as e:
- logger.error(f"chat_completion错误: {str(e)}")
- raise Exception(f"LLM API调用失败: {str(e)}")
- class ArxivClient:
- """arXiv API客户端"""
-
- async def search_papers(
- self,
- query: str,
- max_results: int = 4
- ) -> List[Dict[str, Any]]:
- """搜索arXiv文献"""
- try:
- import arxiv
- 
- # Simplify the query to improve the success rate
- simplified_query = query
- if len(query) > 100: # the query is too long
- # Keep only the core terms
- if '"' in query:
- import re
- main_terms = re.findall(r'"([^"]*)"', query)
- if main_terms and len(main_terms) > 1:
- simplified_query = ' AND '.join([f'"{term}"' for term in main_terms[:2]])
- elif main_terms:
- simplified_query = f'"{main_terms[0]}"'
- else:
- parts = query.split(' AND ')
- if len(parts) > 2:
- simplified_query = ' AND '.join(parts[:2])
-
- logger.info(f"使用查询: {simplified_query}")
-
- search = arxiv.Search(
- query=simplified_query,
- max_results=max_results,
- sort_by=arxiv.SortCriterion.Relevance
- )
-
- results = []
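- # Note: newer releases of the arxiv package prefer arxiv.Client().results(search);
- # Search.results() still works here but may emit a deprecation warning.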
- for paper in search.results():
- results.append({
- "id": paper.get_short_id(),
- "title": paper.title,
- "authors": [author.name for author in paper.authors],
- "summary": paper.summary,
- "published": paper.published.isoformat() if paper.published else None,
- "updated": paper.updated.isoformat() if paper.updated else None,
- "link": paper.pdf_url,
- "source": "arxiv"
- })
-
- return results
- except Exception as e:
- logger.error(f"Error searching arXiv: {str(e)}")
- # Return an empty list when an error occurs
- return []
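- # A minimal end-to-end sketch, assuming DEEPSEEK_API_KEY / DEEPSEEK_API_URL / DEEPSEEK_MODEL
- # are configured in backend.config and network access is available. The arXiv search needs
- # no credentials; the LLM call needs a valid key.
- if __name__ == "__main__":
-     import asyncio
- 
-     async def _demo():
-         arxiv_client = ArxivClient()
-         papers = await arxiv_client.search_papers('"graph neural networks"', max_results=2)
-         print(f"Found {len(papers)} papers")
- 
-         llm = LLMClient()
-         keywords = await llm.extract_keywords("graph neural networks for drug discovery")
-         print(keywords["english_keywords"])
- 
-     asyncio.run(_demo())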