# api_client.py
  1. """
  2. API客户端工具
  3. 用于与外部API(如LLM API、文献检索API)进行交互
  4. """
  5. import httpx
  6. import json
  7. import os
  8. import time
  9. from typing import Dict, List, Any, Optional
  10. import logging
  11. from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
  12. import hashlib
  13. from backend.config import DEEPSEEK_API_KEY, DEEPSEEK_API_URL, DEEPSEEK_MODEL
  14. logger = logging.getLogger(__name__)
  15. class LLMCache:
  16. """LLM响应缓存,减少重复请求"""
  17. def __init__(self, cache_size=200):
  18. self.cache = {}
  19. self.cache_size = cache_size
  20. self.hits = 0
  21. self.misses = 0
  22. def get(self, prompt, temperature):
  23. """获取缓存响应"""
  24. key = self._make_key(prompt, temperature)
  25. if key in self.cache:
  26. self.hits += 1
  27. return self.cache[key]
  28. self.misses += 1
  29. return None
  30. def set(self, prompt, temperature, response):
  31. """设置缓存响应"""
  32. key = self._make_key(prompt, temperature)
  33. self.cache[key] = response
  34. # 如果缓存太大,删除最旧条目
  35. if len(self.cache) > self.cache_size:
  36. oldest_key = next(iter(self.cache))
  37. del self.cache[oldest_key]
  38. def _make_key(self, prompt, temperature):
  39. """创建缓存键"""
  40. # 使用提示内容的MD5哈希作为键
  41. prompt_hash = hashlib.md5(prompt.encode('utf-8')).hexdigest()
  42. return f"{prompt_hash}_{temperature}"
  43. def get_stats(self):
  44. """获取缓存统计信息"""
  45. total = self.hits + self.misses
  46. hit_rate = (self.hits / total) * 100 if total > 0 else 0
  47. return {
  48. "hits": self.hits,
  49. "misses": self.misses,
  50. "total": total,
  51. "hit_rate": hit_rate,
  52. "size": len(self.cache)
  53. }
class LLMClient:
    """Async client for an OpenAI-compatible LLM chat API (DeepSeek)."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        api_url: Optional[str] = None
    ):
        """Create a client; any argument left as None falls back to the
        corresponding value configured in backend.config."""
        self.api_key = api_key or DEEPSEEK_API_KEY
        self.model = model or DEEPSEEK_MODEL
        self.api_url = api_url or DEEPSEEK_API_URL
        if not self.api_key:
            # Not fatal here: requests will simply fail later with an error.
            logger.warning("No LLM API key provided. API calls will fail.")
        # Per-instance response cache shared by all generate_text calls.
        self.cache = LLMCache()
  68. @retry(
  69. stop=stop_after_attempt(3),
  70. wait=wait_exponential(multiplier=1, min=2, max=10),
  71. retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout))
  72. )
  73. async def generate_text(
  74. self,
  75. prompt: str,
  76. temperature: float = 0.3,
  77. max_tokens: int = 1000,
  78. timeout: float = 60.0 # 默认超时时间
  79. ) -> str:
  80. """优化后的文本生成函数,支持自定义超时和重试"""
  81. # 检查缓存
  82. cached_response = self.cache.get(prompt, temperature)
  83. if cached_response:
  84. logger.info("Using cached LLM response")
  85. return cached_response
  86. headers = {
  87. "Content-Type": "application/json",
  88. "Authorization": f"Bearer {self.api_key}"
  89. }
  90. payload = {
  91. "model": self.model,
  92. "messages": [{"role": "user", "content": prompt}],
  93. "temperature": temperature,
  94. "max_tokens": max_tokens
  95. }
  96. # 根据请求大小动态调整超时
  97. adjusted_timeout = timeout
  98. if max_tokens > 1000:
  99. # 更大的请求需要更长的超时
  100. adjusted_timeout = timeout * (max_tokens / 1000) * 1.5
  101. try:
  102. async with httpx.AsyncClient(timeout=adjusted_timeout) as client:
  103. response = await client.post(
  104. f"{self.api_url}/chat/completions",
  105. headers=headers,
  106. json=payload
  107. )
  108. response.raise_for_status()
  109. result = response.json()
  110. # 添加到缓存
  111. content = result["choices"][0]["message"]["content"]
  112. self.cache.set(prompt, temperature, content)
  113. return content
  114. except httpx.TimeoutException as e:
  115. logger.error(f"Timeout calling LLM API ({adjusted_timeout}s): {str(e)}")
  116. raise Exception(f"LLM服务响应超时,请稍后重试") from e
  117. except Exception as e:
  118. logger.error(f"Error calling LLM API: {str(e)}")
  119. raise Exception(f"LLM服务调用失败: {str(e)}") from e
    async def generate_text_chunked(
        self,
        prompt: str,
        max_chunk_tokens: int = 800,
        temperature: float = 0.3
    ) -> str:
        """
        Chunked-generation strategy for large text-generation tasks.

        Large report-generation prompts are split into one outline request
        plus one request per section; the section outputs are concatenated
        into a single markdown report. Smaller tasks go through a single
        standard generate_text call.

        NOTE(review): the chunked branch assumes the prompt contains the
        literal markers 'Research direction:' and (optionally) 'Key papers:';
        split(...)[1] raises IndexError otherwise — confirm against callers.
        """
        # Rough token estimate used to decide whether to simplify/split.
        estimated_tokens = len(prompt.split()) * 1.5  # coarse estimate
        if estimated_tokens > 2000:
            # Strip boilerplate instruction text from oversized prompts.
            prompt = self._simplify_large_prompt(prompt)
        # Use the chunked strategy only for large report-generation prompts.
        if "structured research report" in prompt.lower() and estimated_tokens > 1000:
            # 1. Generate a report outline first.
            outline_prompt = f"""
Based on the following research direction and papers, create a concise outline
for a research report with 3-5 main sections.
Research direction: {prompt.split('Research direction:', 1)[1].split('Key papers:', 1)[0].strip()}
Only provide the outline with section titles, no additional text.
"""
            outline = await self.generate_text(outline_prompt, temperature=0.3, max_tokens=400)
            # 2. Parse section titles: one per non-empty line, skipping
            #    markdown '#' headers.
            sections = [s.strip() for s in outline.split('\n') if s.strip() and not s.strip().startswith('#')]
            # 3. Generate each section separately and stitch them together.
            full_report = ""
            for section in sections[:5]:  # cap at 5 sections
                section_prompt = f"""
Write the content for the "{section}" section of a research report about:
{prompt.split('Research direction:', 1)[1].split('Key papers:', 1)[0].strip()}
Based on these papers:
{prompt.split('Key papers:', 1)[1].strip() if 'Key papers:' in prompt else ''}
Write a detailed section (250-400 words) with relevant information, findings, and analysis.
"""
                section_content = await self.generate_text(
                    section_prompt,
                    temperature=temperature,
                    max_tokens=600
                )
                full_report += f"## {section}\n\n{section_content}\n\n"
            return full_report
        else:
            # Smaller task: a single standard request is enough.
            return await self.generate_text(prompt, temperature=temperature, max_tokens=max_chunk_tokens)
  167. def _simplify_large_prompt(self, prompt: str) -> str:
  168. """简化大型提示,删除冗余指令"""
  169. # 删除重复的指令和多余的空白
  170. lines = [line.strip() for line in prompt.split('\n')]
  171. lines = [line for line in lines if line]
  172. # 移除多余的模板语言
  173. skip_phrases = [
  174. "Please follow these guidelines:",
  175. "Make sure to:",
  176. "Remember to:",
  177. "Follow this structure:",
  178. "Your report should include:"
  179. ]
  180. simplified_lines = []
  181. for line in lines:
  182. if not any(phrase in line for phrase in skip_phrases):
  183. simplified_lines.append(line)
  184. return "\n".join(simplified_lines)
    async def translate_text(
        self,
        text: str,
        source_language: str = "auto",
        target_language: str = "en"
    ) -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_language: Source language code (or "auto" for auto-detection)
            target_language: Target language code

        Returns:
            Translated text
        """
        prompt = f"""
Translate the following text from {source_language} to {target_language}:
{text}
Provide only the translated text with no additional explanations or comments.
"""
        # Low temperature keeps the translation deterministic and literal.
        return await self.generate_text(prompt, temperature=0.1)
    async def detect_language(self, text: str) -> str:
        """
        Detect the language of the given text.

        Args:
            text: Text to detect language for

        Returns:
            Language code (e.g., "en", "zh", "es")
        """
        prompt = f"""
Detect the language of the following text.
Respond with just the ISO 639-1 language code (e.g., "en" for English, "zh" for Chinese, "es" for Spanish).
Text: {text}
Language code:
"""
        response = await self.generate_text(prompt, temperature=0.1)
        # Normalize: the model may answer with extra whitespace or capitals.
        return response.strip().lower()
  222. async def extract_keywords(self, research_topic: str, original_language: str = "en") -> Dict[str, Any]:
  223. """
  224. Extract keywords from research topic in any language
  225. Args:
  226. research_topic: User input research topic
  227. original_language: Original language code
  228. Returns:
  229. Dictionary with keywords in English and the original language
  230. """
  231. # Detect language if not provided
  232. if original_language == "auto":
  233. original_language = await self.detect_language(research_topic)
  234. # Translate to English if not already in English
  235. english_topic = research_topic
  236. if original_language != "en":
  237. english_topic = await self.translate_text(research_topic, original_language, "en")
  238. # Extract English keywords
  239. prompt = f"""
  240. As a research assistant, extract 5-8 key search terms from the following research topic.
  241. These terms should be useful for academic literature search. Consider core concepts, methodology,
  242. application fields, etc. Return only a comma-separated list of keywords without numbering or other text.
  243. Research topic: {english_topic}
  244. Keywords:
  245. """
  246. response = await self.generate_text(prompt, temperature=0.1)
  247. english_keywords = [kw.strip() for kw in response.split(',') if kw.strip()]
  248. # Translate keywords back to original language if needed
  249. original_keywords = english_keywords
  250. if original_language != "en":
  251. keywords_string = ", ".join(english_keywords)
  252. translated_keywords = await self.translate_text(keywords_string, "en", original_language)
  253. original_keywords = [kw.strip() for kw in translated_keywords.split(',') if kw.strip()]
  254. return {
  255. "english_keywords": english_keywords,
  256. "original_keywords": original_keywords,
  257. "language": original_language
  258. }
  259. async def generate_research_directions(self, keywords: List[str], original_language: str = "en") -> Dict[str, Any]:
  260. """
  261. Generate multiple research directions based on keywords
  262. Args:
  263. keywords: List of English keywords
  264. original_language: Original language of user input
  265. Returns:
  266. Dictionary with research directions in English and original language
  267. """
  268. keywords_string = ", ".join(keywords)
  269. prompt = f"""
  270. As a research advisor, generate 4-6 distinct research directions based on these keywords: {keywords_string}
  271. Each research direction should:
  272. 1. Be focused and specific enough for academic research
  273. 2. Represent a different aspect or approach to studying the topic
  274. 3. Be formulated as a clear research question or direction statement
  275. Provide each research direction on a new line without numbering.
  276. """
  277. response = await self.generate_text(prompt, temperature=0.7, max_tokens=1500)
  278. # Parse research directions
  279. english_directions = [dir.strip() for dir in response.splitlines() if dir.strip()]
  280. # Translate research directions to original language if needed
  281. original_directions = english_directions
  282. if original_language != "en":
  283. directions_text = "\n".join(english_directions)
  284. translated_directions = await self.translate_text(directions_text, "en", original_language)
  285. original_directions = [dir.strip() for dir in translated_directions.splitlines() if dir.strip()]
  286. return {
  287. "english_directions": english_directions,
  288. "original_directions": original_directions,
  289. "language": original_language
  290. }
  291. async def batch_translate(self, texts: List[str], source_lang="en", target_lang="zh") -> List[str]:
  292. """批量翻译多个文本片段"""
  293. if not texts:
  294. return []
  295. if source_lang == target_lang:
  296. return texts
  297. # 将多个文本合并,使用特殊标记分隔
  298. separator = "\n===SECTION_BREAK===\n"
  299. combined_text = separator.join(texts)
  300. prompt = f"""
  301. Translate the following text sections from {source_lang} to {target_lang}.
  302. Each section is separated by "{separator}".
  303. Keep each section separate in your translation.
  304. {combined_text}
  305. """
  306. try:
  307. translated = await self.generate_text(prompt, temperature=0.1)
  308. # 拆分翻译结果
  309. sections = translated.split(separator)
  310. # 确保返回与输入相同数量的部分
  311. while len(sections) < len(texts):
  312. sections.append("")
  313. return sections[:len(texts)]
  314. except Exception as e:
  315. logger.error(f"Batch translation error: {str(e)}")
  316. # 出错时返回原文
  317. return texts
  318. @retry(
  319. stop=stop_after_attempt(3), # 最多重试3次
  320. wait=wait_exponential(multiplier=1, min=2, max=10), # 指数退避等待
  321. retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout))
  322. )
  323. async def chat_completion(
  324. self,
  325. messages: List[Dict[str, str]],
  326. temperature: float = 0.3,
  327. max_tokens: int = 1000,
  328. response_format: Optional[Dict[str, str]] = None
  329. ) -> Dict[str, Any]:
  330. """
  331. 发送聊天完成请求到API
  332. """
  333. headers = {
  334. "Content-Type": "application/json",
  335. "Authorization": f"Bearer {self.api_key}"
  336. }
  337. payload = {
  338. "model": self.model,
  339. "messages": messages,
  340. "temperature": temperature,
  341. "max_tokens": max_tokens
  342. }
  343. if response_format:
  344. payload["response_format"] = response_format
  345. try:
  346. # 增加超时时间到120秒
  347. logger.info(f"正在发送请求到 {self.api_url}/chat/completions")
  348. logger.info(f"使用模型: {self.model}")
  349. logger.info(f"API密钥前4位: {self.api_key[:4]}****")
  350. async with httpx.AsyncClient(timeout=120.0) as client:
  351. response = await client.post(
  352. f"{self.api_url}/chat/completions",
  353. headers=headers,
  354. json=payload
  355. )
  356. # 记录完整的响应状态和正文
  357. logger.info(f"API响应状态码: {response.status_code}")
  358. if response.status_code != 200:
  359. logger.error(f"API响应错误: {response.text}")
  360. response.raise_for_status()
  361. return response.json()
  362. except httpx.HTTPStatusError as e:
  363. logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}")
  364. if e.response.status_code == 404:
  365. raise Exception("API端点未找到,请检查API URL配置是否正确")
  366. elif e.response.status_code == 401:
  367. raise Exception("API密钥无效或未授权,请检查您的API密钥")
  368. else:
  369. raise Exception(f"API请求失败: HTTP {e.response.status_code}")
  370. except httpx.RequestError as e:
  371. logger.error(f"请求错误: {str(e)}")
  372. raise Exception(f"无法连接到LLM API服务器: {str(e)}")
  373. except Exception as e:
  374. logger.error(f"chat_completion错误: {str(e)}")
  375. raise Exception(f"LLM API调用失败: {str(e)}")
  376. class ArxivClient:
  377. """arXiv API客户端"""
  378. async def search_papers(
  379. self,
  380. query: str,
  381. max_results: int = 4
  382. ) -> List[Dict[str, Any]]:
  383. """搜索arXiv文献"""
  384. try:
  385. import arxiv
  386. import asyncio
  387. # 简化查询以提高成功率
  388. simplified_query = query
  389. if len(query) > 100: # 如果查询太长
  390. # 只保留核心部分
  391. if '"' in query:
  392. import re
  393. main_terms = re.findall(r'"([^"]*)"', query)
  394. if main_terms and len(main_terms) > 1:
  395. simplified_query = ' AND '.join([f'"{term}"' for term in main_terms[:2]])
  396. elif main_terms:
  397. simplified_query = f'"{main_terms[0]}"'
  398. else:
  399. parts = query.split(' AND ')
  400. if len(parts) > 2:
  401. simplified_query = ' AND '.join(parts[:2])
  402. logger.info(f"使用查询: {simplified_query}")
  403. search = arxiv.Search(
  404. query=simplified_query,
  405. max_results=max_results,
  406. sort_by=arxiv.SortCriterion.Relevance
  407. )
  408. results = []
  409. for paper in search.results():
  410. results.append({
  411. "id": paper.get_short_id(),
  412. "title": paper.title,
  413. "authors": [author.name for author in paper.authors],
  414. "summary": paper.summary,
  415. "published": paper.published.isoformat() if paper.published else None,
  416. "updated": paper.updated.isoformat() if paper.updated else None,
  417. "link": paper.pdf_url,
  418. "source": "arxiv"
  419. })
  420. return results
  421. except Exception as e:
  422. logger.error(f"Error searching arXiv: {str(e)}")
  423. # 发生错误时返回空列表
  424. return []