api_client.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. """
  2. API客户端工具
  3. 用于与外部API(如LLM API、文献检索API)进行交互
  4. """
  5. import httpx
  6. import json
  7. import os
  8. import time
  9. from typing import Dict, List, Any, Optional
  10. import logging
  11. from backend.config import DEEPSEEK_API_KEY, DEEPSEEK_API_URL, DEEPSEEK_MODEL
  12. logger = logging.getLogger(__name__)
  13. class LLMClient:
  14. """大型语言模型API客户端"""
  15. def __init__(
  16. self,
  17. api_key: Optional[str] = None,
  18. model: Optional[str] = None,
  19. api_url: Optional[str] = None
  20. ):
  21. self.api_key = api_key or DEEPSEEK_API_KEY
  22. self.model = model or DEEPSEEK_MODEL
  23. self.api_url = api_url or DEEPSEEK_API_URL
  24. if not self.api_key:
  25. logger.warning("No LLM API key provided. API calls will fail.")
  26. async def generate_text(
  27. self,
  28. prompt: str,
  29. temperature: float = 0.3,
  30. max_tokens: int = 1000
  31. ) -> str:
  32. """
  33. 调用LLM生成文本
  34. Args:
  35. prompt: 提示文本
  36. temperature: 温度参数,控制随机性
  37. max_tokens: 最大生成token数
  38. Returns:
  39. 生成的文本
  40. """
  41. headers = {
  42. "Content-Type": "application/json",
  43. "Authorization": f"Bearer {self.api_key}"
  44. }
  45. payload = {
  46. "model": self.model,
  47. "messages": [{"role": "user", "content": prompt}],
  48. "temperature": temperature,
  49. "max_tokens": max_tokens
  50. }
  51. try:
  52. async with httpx.AsyncClient(timeout=60.0) as client:
  53. response = await client.post(
  54. f"{self.api_url}/chat/completions",
  55. headers=headers,
  56. json=payload
  57. )
  58. response.raise_for_status()
  59. result = response.json()
  60. return result["choices"][0]["message"]["content"]
  61. except Exception as e:
  62. logger.error(f"Error calling LLM API: {str(e)}")
  63. return f"Error generating text: {str(e)}"
  64. class ArxivClient:
  65. """arXiv API客户端"""
  66. async def search_papers(
  67. self,
  68. query: str,
  69. max_results: int = 10
  70. ) -> List[Dict[str, Any]]:
  71. """
  72. 搜索arXiv文献
  73. Args:
  74. query: 搜索查询
  75. max_results: 最大结果数量
  76. Returns:
  77. 文献列表
  78. """
  79. try:
  80. import arxiv
  81. search = arxiv.Search(
  82. query=query,
  83. max_results=max_results,
  84. sort_by=arxiv.SortCriterion.Relevance
  85. )
  86. results = []
  87. for paper in await search.results():
  88. results.append({
  89. "id": paper.get_short_id(),
  90. "title": paper.title,
  91. "authors": [author.name for author in paper.authors],
  92. "summary": paper.summary,
  93. "published": paper.published.isoformat() if paper.published else None,
  94. "updated": paper.updated.isoformat() if paper.updated else None,
  95. "link": paper.pdf_url,
  96. "source": "arxiv"
  97. })
  98. return results
  99. except Exception as e:
  100. logger.error(f"Error searching arXiv: {str(e)}")
  101. return []