123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234 |
- """
- 聊天API
- 处理与AI助手的对话
- """
- from fastapi import APIRouter, HTTPException, Depends, Request
- from pydantic import BaseModel
- from typing import List, Dict, Any, Optional
- import logging
- from sqlalchemy.orm import Session
- from backend.utils.api_client import LLMClient
- from backend.core.database import get_db
- from backend.core.models import User, ChatHistory, ChatMessage
- from backend.core.auth import get_current_user
- router = APIRouter(prefix="/chat", tags=["chat"])
- logger = logging.getLogger(__name__)
class Message(BaseModel):
    """A single chat turn in OpenAI-compatible format."""
    role: str      # e.g. "user" or "assistant" (not validated here)
    content: str   # message text
class ChatRequest(BaseModel):
    """Request body carrying a full conversation to forward to the LLM."""
    messages: List[Message]             # complete history, oldest first
    temperature: Optional[float] = 0.7  # sampling temperature passed to the LLM
    max_tokens: Optional[int] = 1000    # cap on generated tokens
class MessageRequest(BaseModel):
    """Request body for /chat/send-message."""
    content: str
    chat_id: Optional[int] = None  # optional; when None a new chat is created
class MessageResponse(BaseModel):
    """One stored message echoed back to the client."""
    role: str
    content: str
class ChatResponse(BaseModel):
    """Response: the chat's id plus its full message history."""
    chat_id: int
    messages: List[MessageResponse]
class ChatCompletionResponse(BaseModel):
    """Legacy-format response: just the assistant message and model name."""
    message: Message
    model: str
@router.post("/send-message", response_model=ChatResponse)
async def send_message(
    request: MessageRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send one user message and return the AI response.

    Finds the user's existing chat (or creates a new one), persists the
    user message, calls the LLM with the full conversation, persists the
    assistant reply, and returns the chat id plus all stored messages.

    Raises:
        HTTPException 404: chat_id was given but does not exist or is not
            owned by the current user.
        HTTPException 500: any failure while calling the LLM or persisting;
            the transaction is rolled back.
    """
    if request.chat_id:
        # Existing chat: ownership is enforced in the same query.
        chat = db.query(ChatHistory)\
            .filter(ChatHistory.id == request.chat_id, ChatHistory.user_id == current_user.id)\
            .first()
        if not chat:
            raise HTTPException(status_code=404, detail="聊天历史不存在或无权访问")
    else:
        # New chat; the placeholder title is replaced below after the
        # first exchange.
        chat = ChatHistory(
            user_id=current_user.id,
            title="新对话"
        )
        db.add(chat)
        db.commit()
        db.refresh(chat)

    try:
        # Build the LLM payload from the *prior* history, then append the
        # incoming message exactly once — BEFORE persisting it. The old code
        # added the new ChatMessage first, queried the history (autoflush
        # made the pending row show up in the results), and then appended
        # request.content again, so the current user message reached the
        # model twice.
        prior_messages = db.query(ChatMessage)\
            .filter(ChatMessage.chat_id == chat.id)\
            .order_by(ChatMessage.timestamp)\
            .all()
        messages = [{"role": msg.role, "content": msg.content} for msg in prior_messages]
        messages.append({"role": "user", "content": request.content})

        # Persist the user message (committed together with the reply below,
        # so a failed LLM call rolls both back).
        user_message = ChatMessage(
            chat_id=chat.id,
            role="user",
            content=request.content
        )
        db.add(user_message)

        llm_client = LLMClient()
        response = await llm_client.chat_completion(
            messages=messages,
            temperature=0.7,
            max_tokens=1000
        )
        ai_response = response["choices"][0]["message"]["content"]

        # Persist the assistant reply.
        ai_message = ChatMessage(
            chat_id=chat.id,
            role="assistant",
            content=ai_response
        )
        db.add(ai_message)

        # First exchange only: derive the chat title from the opening message.
        if len(chat.messages) <= 2:
            chat.title = request.content[:30] + ("..." if len(request.content) > 30 else "")

        db.commit()

        # Re-read so ordering/timestamps come from the database.
        all_messages = db.query(ChatMessage)\
            .filter(ChatMessage.chat_id == chat.id)\
            .order_by(ChatMessage.timestamp)\
            .all()

        return {
            "chat_id": chat.id,
            "messages": [{"role": msg.role, "content": msg.content} for msg in all_messages]
        }

    except Exception as e:
        db.rollback()
        logger.error(f"AI处理错误: {str(e)}")
        raise HTTPException(status_code=500, detail=f"AI处理错误: {str(e)}")
@router.post("", response_model=ChatResponse)
async def chat_completion_root(
    request: ChatRequest,
    db: Session = Depends(get_db)
):
    """Handle a chat request and return the AI assistant's reply.

    Creates a throwaway ChatHistory (user_id=None), stores the user-role
    messages from the request plus the assistant reply, and returns the
    stored conversation in ChatResponse format.

    Raises:
        HTTPException 500: on any LLM or persistence failure; the session
            is rolled back (the old code left the session dirty here,
            unlike send_message which rolls back on failure).
    """
    try:
        llm_client = LLMClient()

        # Payload forwarded to the model: the request history as-is.
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Anonymous, temporary chat record (not linked to a user).
        chat = ChatHistory(
            user_id=None,
            title="临时对话"
        )
        db.add(chat)
        db.commit()
        db.refresh(chat)

        # Persist only the user-role turns from the incoming history.
        # NOTE(review): assistant turns from the request are deliberately
        # dropped here — confirm this matches the client's expectations.
        for msg in request.messages:
            if msg.role == "user":
                user_message = ChatMessage(
                    chat_id=chat.id,
                    role=msg.role,
                    content=msg.content
                )
                db.add(user_message)

        response = await llm_client.chat_completion(
            messages=messages,
            temperature=request.temperature,
            max_tokens=request.max_tokens
        )

        assistant_message = response["choices"][0]["message"]

        # Persist the assistant reply.
        ai_message = ChatMessage(
            chat_id=chat.id,
            role=assistant_message["role"],
            content=assistant_message["content"]
        )
        db.add(ai_message)
        db.commit()

        # Return everything stored for this chat, in timestamp order.
        all_messages = db.query(ChatMessage)\
            .filter(ChatMessage.chat_id == chat.id)\
            .order_by(ChatMessage.timestamp)\
            .all()

        return ChatResponse(
            chat_id=chat.id,
            messages=[
                MessageResponse(role=msg.role, content=msg.content)
                for msg in all_messages
            ]
        )
    except Exception as e:
        # Fix: roll back the failed transaction so the session is reusable.
        db.rollback()
        logger.error(f"聊天API错误: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}")
@router.post("/chat", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatRequest):
    """Legacy chat endpoint: forward the conversation to the LLM and
    return only the assistant's reply (nothing is persisted).

    Raises:
        HTTPException 500: on any failure while calling the LLM.
    """
    try:
        # Convert pydantic messages into the plain-dict payload the
        # client expects.
        payload = [
            {"role": turn.role, "content": turn.content}
            for turn in request.messages
        ]

        client = LLMClient()
        result = await client.chat_completion(
            messages=payload,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )

        # First choice carries the assistant's reply.
        reply = result["choices"][0]["message"]

        return ChatCompletionResponse(
            message=Message(role=reply["role"], content=reply["content"]),
            model=result["model"],
        )
    except Exception as e:
        logger.error(f"聊天API错误: {str(e)}")
        raise HTTPException(status_code=500, detail=f"处理聊天请求时出错: {str(e)}")
|