chat.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
"""
聊天API
处理与AI助手的对话
"""
from fastapi import APIRouter, HTTPException, Depends, Request
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import logging
from sqlalchemy.orm import Session
from backend.utils.api_client import LLMClient
from backend.core.database import get_db
from backend.core.models import User, ChatHistory, ChatMessage
from backend.core.auth import get_current_user
# Router for all /chat endpoints; logger scoped to this module.
router = APIRouter(prefix="/chat", tags=["chat"])
logger = logging.getLogger(__name__)
class Message(BaseModel):
    """A single chat message exchanged with the LLM."""
    role: str     # e.g. "user" / "assistant" — presumably OpenAI-style roles; TODO confirm
    content: str  # message text
class ChatRequest(BaseModel):
    """Request body for the chat-completion endpoints."""
    messages: List[Message]             # full conversation history to forward to the LLM
    temperature: Optional[float] = 0.7  # sampling temperature passed through to the LLM
    max_tokens: Optional[int] = 1000    # reply length cap passed through to the LLM
class MessageRequest(BaseModel):
    """Request body for /chat/send-message."""
    content: str                   # the user's message text
    chat_id: Optional[int] = None  # optional; when empty, a new conversation is created
class MessageResponse(BaseModel):
    """One message in a returned conversation."""
    role: str     # "user" or "assistant" as stored in ChatMessage
    content: str  # message text
class ChatResponse(BaseModel):
    """Response body: a chat id plus its full ordered message list."""
    chat_id: int
    messages: List[MessageResponse]
class ChatCompletionResponse(BaseModel):
    """Legacy response shape: a single assistant message plus the model name."""
    message: Message
    model: str  # model identifier reported by the LLM response
  35. @router.post("/send-message", response_model=ChatResponse)
  36. async def send_message(
  37. request: MessageRequest,
  38. current_user: User = Depends(get_current_user),
  39. db: Session = Depends(get_db)
  40. ):
  41. """发送消息并获取AI响应"""
  42. # 获取或创建聊天历史
  43. if request.chat_id:
  44. # 使用现有聊天
  45. chat = db.query(ChatHistory)\
  46. .filter(ChatHistory.id == request.chat_id, ChatHistory.user_id == current_user.id)\
  47. .first()
  48. if not chat:
  49. raise HTTPException(status_code=404, detail="聊天历史不存在或无权访问")
  50. else:
  51. # 创建新聊天
  52. chat = ChatHistory(
  53. user_id=current_user.id,
  54. title="新对话" # 可以从第一条消息中提取标题
  55. )
  56. db.add(chat)
  57. db.commit()
  58. db.refresh(chat)
  59. # 保存用户消息
  60. user_message = ChatMessage(
  61. chat_id=chat.id,
  62. role="user",
  63. content=request.content
  64. )
  65. db.add(user_message)
  66. # 调用AI获取回复
  67. try:
  68. # 获取当前聊天的所有消息历史
  69. chat_history = db.query(ChatMessage)\
  70. .filter(ChatMessage.chat_id == chat.id)\
  71. .order_by(ChatMessage.timestamp)\
  72. .all()
  73. # 构建消息列表
  74. messages = [{"role": msg.role, "content": msg.content} for msg in chat_history]
  75. # 添加用户当前消息
  76. messages.append({"role": "user", "content": request.content})
  77. # 调用AI获取回复
  78. llm_client = LLMClient()
  79. response = await llm_client.chat_completion(
  80. messages=messages,
  81. temperature=0.7,
  82. max_tokens=1000
  83. )
  84. # 从响应中提取AI回复
  85. ai_response = response["choices"][0]["message"]["content"]
  86. # 保存AI回复
  87. ai_message = ChatMessage(
  88. chat_id=chat.id,
  89. role="assistant",
  90. content=ai_response
  91. )
  92. db.add(ai_message)
  93. # 更新聊天标题(如果是第一条消息)
  94. if len(chat.messages) <= 2: # 只有当前这两条消息
  95. # 从第一条消息生成标题
  96. chat.title = request.content[:30] + ("..." if len(request.content) > 30 else "")
  97. db.commit()
  98. # 获取消息
  99. messages = db.query(ChatMessage)\
  100. .filter(ChatMessage.chat_id == chat.id)\
  101. .order_by(ChatMessage.timestamp)\
  102. .all()
  103. return {
  104. "chat_id": chat.id,
  105. "messages": [{"role": msg.role, "content": msg.content} for msg in messages]
  106. }
  107. except Exception as e:
  108. db.rollback()
  109. logger.error(f"AI处理错误: {str(e)}")
  110. raise HTTPException(status_code=500, detail=f"AI处理错误: {str(e)}")
  111. @router.post("", response_model=ChatResponse)
  112. async def chat_completion_root(
  113. request: ChatRequest,
  114. db: Session = Depends(get_db)
  115. ):
  116. """
  117. 处理聊天请求,返回AI助手的回复
  118. """
  119. try:
  120. llm_client = LLMClient()
  121. # 提取消息列表
  122. messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
  123. # 创建一个临时聊天记录
  124. chat = ChatHistory(
  125. user_id=None, # 临时聊天,不关联用户
  126. title="临时对话"
  127. )
  128. db.add(chat)
  129. db.commit()
  130. db.refresh(chat)
  131. # 保存用户的历史消息
  132. for msg in request.messages:
  133. if msg.role == "user": # 只保存用户消息
  134. user_message = ChatMessage(
  135. chat_id=chat.id,
  136. role=msg.role,
  137. content=msg.content
  138. )
  139. db.add(user_message)
  140. # 调用LLM客户端的chat_completion方法
  141. response = await llm_client.chat_completion(
  142. messages=messages,
  143. temperature=request.temperature,
  144. max_tokens=request.max_tokens
  145. )
  146. # 从响应中提取助手消息
  147. assistant_message = response["choices"][0]["message"]
  148. # 保存AI响应
  149. ai_message = ChatMessage(
  150. chat_id=chat.id,
  151. role=assistant_message["role"],
  152. content=assistant_message["content"]
  153. )
  154. db.add(ai_message)
  155. db.commit()
  156. # 获取所有消息
  157. all_messages = db.query(ChatMessage)\
  158. .filter(ChatMessage.chat_id == chat.id)\
  159. .order_by(ChatMessage.timestamp)\
  160. .all()
  161. # 返回符合ChatResponse格式的响应
  162. return ChatResponse(
  163. chat_id=chat.id,
  164. messages=[
  165. MessageResponse(role=msg.role, content=msg.content)
  166. for msg in all_messages
  167. ]
  168. )
  169. except Exception as e:
  170. logger.error(f"聊天API错误: {str(e)}")
  171. raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}")
  172. @router.post("/chat", response_model=ChatCompletionResponse)
  173. async def chat_completion(request: ChatRequest):
  174. """
  175. 处理聊天请求,返回AI助手的回复 (旧格式)
  176. """
  177. try:
  178. llm_client = LLMClient()
  179. # 提取消息列表
  180. messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
  181. # 调用LLM客户端的chat_completion方法
  182. response = await llm_client.chat_completion(
  183. messages=messages,
  184. temperature=request.temperature,
  185. max_tokens=request.max_tokens
  186. )
  187. # 从响应中提取助手消息
  188. assistant_message = response["choices"][0]["message"]
  189. return ChatCompletionResponse(
  190. message=Message(
  191. role=assistant_message["role"],
  192. content=assistant_message["content"]
  193. ),
  194. model=response["model"]
  195. )
  196. except Exception as e:
  197. logger.error(f"聊天API错误: {str(e)}")
  198. raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}")