"""Chat API: handles conversations with the AI assistant."""
from fastapi import APIRouter, HTTPException, Depends, Request
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import logging
from sqlalchemy.orm import Session

from backend.utils.api_client import LLMClient
from backend.core.database import get_db
from backend.core.models import User, ChatHistory, ChatMessage
from backend.core.auth import get_current_user

router = APIRouter(prefix="/chat", tags=["chat"])
logger = logging.getLogger(__name__)


# A single role/content chat message (request input and legacy response output).
class Message(BaseModel):
    role: str
    content: str


# Stateless chat request: the full message list plus sampling options
# forwarded to LLMClient.chat_completion.
class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1000


# Request body for POST /chat/send-message.
class MessageRequest(BaseModel):
    content: str
    chat_id: Optional[int] = None  # optional; when empty (or 0) a new conversation is created


# One message in a ChatResponse transcript.
class MessageResponse(BaseModel):
    role: str
    content: str


# Response for endpoints that persist a conversation: its id and its messages.
class ChatResponse(BaseModel):
    chat_id: int
    messages: List[MessageResponse]


# Response for the legacy stateless POST /chat/chat endpoint.
class ChatCompletionResponse(BaseModel):
    message: Message
    model: str


def _to_llm_messages(messages: List[Message]) -> List[Dict[str, str]]:
    """Convert request messages into the role/content dicts passed to LLMClient."""
    return [{"role": m.role, "content": m.content} for m in messages]


def _extract_assistant_message(response: Any) -> Dict[str, Any]:
    """Return the first choice's message from an LLMClient.chat_completion result.

    A response without ``choices[0].message`` raises KeyError/IndexError/TypeError,
    which every caller converts into an HTTP 500.
    """
    return response["choices"][0]["message"]


def _make_title(content: str, max_len: int = 30) -> str:
    """Build a chat title from the first ``max_len`` characters, adding '...' if truncated."""
    suffix = "..." if len(content) > max_len else ""
    return content[:max_len] + suffix


def _load_history(db: Session, chat_id: int) -> List[Dict[str, str]]:
    """Return the stored messages of a chat as role/content dicts, ordered by timestamp."""
    rows = (
        db.query(ChatMessage)
        .filter(ChatMessage.chat_id == chat_id)
        .order_by(ChatMessage.timestamp)
        .all()
    )
    return [{"role": row.role, "content": row.content} for row in rows]


@router.post("/send-message", response_model=ChatResponse)
async def send_message(
    request: MessageRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Send a message and get the AI assistant's reply.

    If ``request.chat_id`` is set, the message is appended to that conversation,
    which must belong to the current user. Otherwise a new conversation is
    created. The stored history plus the new message is sent to the LLM.

    Nothing is written until the LLM call succeeds. Then the conversation (if
    new), the user message, the assistant reply and, for a conversation with no
    prior messages, a title derived from this message are committed together.
    A failure therefore leaves no partial rows behind.

    Returns:
        A ChatResponse-shaped dict: the chat id and the full transcript
        (prior history, the new user message, the assistant reply).

    Raises:
        HTTPException: 404 if ``chat_id`` does not exist or belongs to another
            user; 500 if loading history, the LLM call, or saving fails.
    """
    chat = None
    if request.chat_id:
        # Existing conversation: only its owner may post to it.
        chat = (
            db.query(ChatHistory)
            .filter(ChatHistory.id == request.chat_id, ChatHistory.user_id == current_user.id)
            .first()
        )
        if not chat:
            raise HTTPException(status_code=404, detail="聊天历史不存在或无权访问")

    try:
        # Read the prior history BEFORE adding anything to the session. The new
        # user message is then included exactly once, whatever the session's
        # autoflush setting.
        history = _load_history(db, chat.id) if chat is not None else []
        llm_messages = history + [{"role": "user", "content": request.content}]

        llm_client = LLMClient()
        response = await llm_client.chat_completion(
            messages=llm_messages,
            temperature=0.7,
            max_tokens=1000,
        )
        ai_response = _extract_assistant_message(response)["content"]

        if chat is None:
            # Create the conversation only after the LLM call succeeded, so a
            # failed call does not leave an empty chat behind.
            chat = ChatHistory(
                user_id=current_user.id,
                title="新对话",  # placeholder; replaced below from the first message
            )
            db.add(chat)
            db.flush()  # assigns chat.id without committing

        db.add(ChatMessage(chat_id=chat.id, role="user", content=request.content))
        db.add(ChatMessage(chat_id=chat.id, role="assistant", content=ai_response))
        # NOTE(review): both rows are flushed in the same commit. If
        # ChatMessage.timestamp is assigned at insert time, later reads ordered
        # only by timestamp could tie. Consider a secondary order key (e.g. the
        # primary key) if the model has one.

        if not history:
            # First exchange in this conversation: derive the title from it.
            chat.title = _make_title(request.content)

        chat_id = chat.id  # capture before commit expires the instance
        db.commit()
    except Exception as e:
        db.rollback()
        logger.exception("AI处理错误: %s", e)
        # NOTE(review): the detail echoes the raw exception text to the client,
        # which may leak internals; consider a generic message.
        raise HTTPException(status_code=500, detail=f"AI处理错误: {str(e)}") from e

    transcript = llm_messages + [{"role": "assistant", "content": ai_response}]
    return {"chat_id": chat_id, "messages": transcript}


@router.post("", response_model=ChatResponse)
async def chat_completion_root(
    request: ChatRequest,
    db: Session = Depends(get_db),
):
    """Handle a chat request and return the AI assistant's reply.

    All request messages are sent to the LLM. After the call succeeds, a
    temporary conversation (no owning user) is stored. It holds the request's
    user-role messages and the assistant reply, committed in one transaction.

    Returns:
        ChatResponse with the new chat id and the stored messages (user
        messages in request order, then the assistant reply).

    Raises:
        HTTPException: 500 if the LLM call or saving fails.
    """
    # NOTE(review): this endpoint has no authentication dependency and stores
    # rows with user_id=None; confirm that is intended.
    try:
        llm_client = LLMClient()
        response = await llm_client.chat_completion(
            messages=_to_llm_messages(request.messages),
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        assistant_message = _extract_assistant_message(response)

        # Only user-authored messages from the request are persisted; other
        # roles are sent to the LLM but not stored.
        saved = [
            MessageResponse(role=m.role, content=m.content)
            for m in request.messages
            if m.role == "user"
        ]
        saved.append(
            MessageResponse(
                role=assistant_message["role"],
                content=assistant_message["content"],
            )
        )

        # Write only after the LLM call succeeded, so failures leave no orphan chat.
        chat = ChatHistory(
            user_id=None,  # temporary chat, not linked to a user
            title="临时对话",
        )
        db.add(chat)
        db.flush()  # assigns chat.id without committing
        db.add_all(
            [ChatMessage(chat_id=chat.id, role=m.role, content=m.content) for m in saved]
        )

        chat_id = chat.id  # capture before commit expires the instance
        db.commit()
    except Exception as e:
        db.rollback()
        logger.exception("聊天API错误: %s", e)
        raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}") from e

    return ChatResponse(chat_id=chat_id, messages=saved)


@router.post("/chat", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatRequest):
    """Handle a chat request and return the AI assistant's reply (legacy format).

    Stateless: nothing is persisted. The reply and the model name reported by
    the LLM response are returned.

    Raises:
        HTTPException: 500 if the LLM call fails or its response is malformed.
    """
    try:
        llm_client = LLMClient()
        response = await llm_client.chat_completion(
            messages=_to_llm_messages(request.messages),
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        assistant_message = _extract_assistant_message(response)
        return ChatCompletionResponse(
            message=Message(
                role=assistant_message["role"],
                content=assistant_message["content"],
            ),
            model=response["model"],
        )
    except Exception as e:
        logger.exception("聊天API错误: %s", e)
        raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}") from e