"""Chat history API routes."""

import logging
from datetime import datetime
from enum import Enum
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, validator
from sqlalchemy import func
from sqlalchemy.orm import Session

from backend.core.auth import get_current_user
from backend.core.database import get_db
from backend.core.models import ChatHistory, ChatMessage, User

logger = logging.getLogger(__name__)

router = APIRouter(tags=["chat"], prefix="/chat-history")

# Maximum number of characters of the latest user message shown as a preview.
_PREVIEW_LENGTH = 100

# NOTE(review): every endpoint below is `async def` but uses a synchronous
# SQLAlchemy Session, so each DB call blocks the event loop. Consider plain
# `def` endpoints (run in FastAPI's threadpool) or an async session.


class MessageRole(str, Enum):
    """Allowed author roles for a chat message."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class MessageCreate(BaseModel):
    """Request body for a single message."""

    role: MessageRole
    content: str

    @validator("content")
    def content_must_not_be_empty(cls, v):
        """Reject empty or whitespace-only message content."""
        if not v or not v.strip():
            raise ValueError("消息内容不能为空")
        return v

    @validator("role", pre=True)
    def validate_role(cls, v):
        """Coerce the incoming role to a MessageRole, case-insensitively.

        Unrecognised strings and non-string values fall back to
        MessageRole.USER instead of being rejected; this lenient behaviour is
        kept deliberately for existing clients.
        """
        if isinstance(v, MessageRole):
            return v
        if isinstance(v, str):
            try:
                # Enum lookup by value: "user" / "assistant" / "system".
                return MessageRole(v.strip().lower())
            except ValueError:
                pass
        # NOTE(review): silently downgrading unknown roles to "user" can hide
        # client bugs; consider raising a validation error instead.
        return MessageRole.USER


class MessageResponse(BaseModel):
    """A stored chat message as returned by the API."""

    id: int
    role: str
    content: str
    timestamp: datetime

    class Config:
        orm_mode = True


class ChatHistoryCreate(BaseModel):
    """Request body for creating a chat together with its initial messages."""

    title: str
    messages: List[MessageCreate]


class ChatHistoryResponse(BaseModel):
    """A chat with all of its messages."""

    id: int
    title: str
    created_at: datetime
    updated_at: datetime
    messages: List[MessageResponse]

    class Config:
        orm_mode = True


class ChatHistorySummary(BaseModel):
    """A chat listing entry, without its full message list."""

    id: int
    title: str
    created_at: datetime
    updated_at: datetime
    # Computed by get_chat_histories (they are not ORM columns). Without these
    # declarations the response_model would silently drop the computed values.
    message_count: int = 0
    preview: str = ""
    has_code: bool = False

    class Config:
        orm_mode = True


class ChatHistoryUpdate(BaseModel):
    """Partial update for a chat; currently only the title can change."""

    title: Optional[str] = None


def _get_owned_chat_or_404(db: Session, chat_id: int, user_id) -> ChatHistory:
    """Return the chat with `chat_id` owned by `user_id`, or raise HTTP 404.

    A chat owned by someone else yields the same 404 as a missing one, so
    callers cannot probe for other users' chat ids.
    """
    chat = (
        db.query(ChatHistory)
        .filter(ChatHistory.id == chat_id, ChatHistory.user_id == user_id)
        .first()
    )
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="聊天历史不存在或无权访问",
        )
    return chat


def _latest_user_message_preview(db: Session, chat_id: int) -> str:
    """Return a truncated preview of the chat's newest user message.

    Returns "" when the chat has no user message. Content longer than
    _PREVIEW_LENGTH characters is cut and suffixed with "...".
    """
    last_user_message = (
        db.query(ChatMessage)
        .filter(
            ChatMessage.chat_id == chat_id,
            ChatMessage.role == MessageRole.USER.value,
        )
        .order_by(ChatMessage.timestamp.desc())
        .first()
    )
    if not last_user_message:
        return ""
    content = last_user_message.content or ""
    if len(content) > _PREVIEW_LENGTH:
        return content[:_PREVIEW_LENGTH] + "..."
    return content


@router.post("/", response_model=ChatHistoryResponse)
async def create_chat_history(
    chat_data: ChatHistoryCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a new chat history together with its initial messages.

    The chat row and all message rows are written in one transaction.
    Any error rolls the whole thing back, is logged, and is re-raised, so no
    partially populated chat is left behind.
    """
    logger.debug(
        "Creating chat history %r with %d messages",
        chat_data.title,
        len(chat_data.messages),
    )
    try:
        chat = ChatHistory(user_id=current_user.id, title=chat_data.title)
        db.add(chat)
        # flush (not commit) assigns chat.id while keeping the transaction open.
        db.flush()

        db.add_all(
            [
                ChatMessage(
                    chat_id=chat.id,
                    role=msg_data.role.value,
                    content=msg_data.content,
                )
                for msg_data in chat_data.messages
            ]
        )
        db.commit()
    except Exception:
        # Undo the half-written chat and leave the session usable.
        db.rollback()
        logger.exception("保存聊天历史出错")
        raise

    db.refresh(chat)
    return chat


@router.get("/", response_model=List[ChatHistorySummary])
async def get_chat_histories(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's chats, newest activity first.

    Each entry is enriched with `message_count`, `preview` (latest user
    message, truncated) and `has_code` (any message containing a ``` fence).
    """
    chats = (
        db.query(ChatHistory)
        .filter(ChatHistory.user_id == current_user.id)
        .order_by(ChatHistory.updated_at.desc())
        .all()
    )
    if not chats:
        return chats

    chat_ids = [chat.id for chat in chats]

    # One grouped query for all message counts instead of a COUNT per chat.
    message_counts = dict(
        db.query(ChatMessage.chat_id, func.count())
        .filter(ChatMessage.chat_id.in_(chat_ids))
        .group_by(ChatMessage.chat_id)
        .all()
    )

    # Ids of chats with at least one message containing a fenced code block.
    chats_with_code = {
        row[0]
        for row in db.query(ChatMessage.chat_id)
        .filter(
            ChatMessage.chat_id.in_(chat_ids),
            ChatMessage.content.like("%```%"),
        )
        .distinct()
        .all()
    }

    # These are plain (non-mapped) attributes read by ChatHistorySummary.
    for chat in chats:
        chat.message_count = message_counts.get(chat.id, 0)
        chat.has_code = chat.id in chats_with_code
        chat.preview = _latest_user_message_preview(db, chat.id)

    return chats


@router.get("/{chat_id}", response_model=ChatHistoryResponse)
async def get_chat_history(
    chat_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one chat with all of its messages (404 if missing or not owned)."""
    return _get_owned_chat_or_404(db, chat_id, current_user.id)


@router.post("/{chat_id}/messages", response_model=MessageResponse)
async def add_message(
    chat_id: int,
    message: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Append a message to an existing chat owned by the current user."""
    chat = _get_owned_chat_or_404(db, chat_id, current_user.id)

    db_message = ChatMessage(
        chat_id=chat.id,
        role=message.role.value,
        content=message.content,
    )
    # Bump the parent chat so it sorts first in the history listing.
    chat.updated_at = datetime.utcnow()

    db.add(db_message)
    db.commit()
    db.refresh(db_message)
    return db_message


@router.delete("/{chat_id}")
async def delete_chat_history(
    chat_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete a chat and all of its messages."""
    chat = _get_owned_chat_or_404(db, chat_id, current_user.id)

    # Remove the messages explicitly rather than relying on a DB/ORM cascade.
    db.query(ChatMessage).filter(ChatMessage.chat_id == chat_id).delete()
    db.delete(chat)
    db.commit()

    return {"message": "聊天历史已删除"}


@router.put("/{chat_id}", response_model=ChatHistoryResponse)
async def update_chat_history(
    chat_id: int,
    update_data: ChatHistoryUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Update a chat's title.

    A missing, empty, or whitespace-only title leaves the chat untouched.
    `updated_at` is only bumped when the title actually changes, so no-op
    updates do not reorder the history listing.
    """
    chat = _get_owned_chat_or_404(db, chat_id, current_user.id)

    new_title = update_data.title
    if new_title and new_title.strip():
        chat.title = new_title
        chat.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(chat)
    return chat