123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- """
- 聊天历史API路由
- """
- from fastapi import APIRouter, Depends, HTTPException, status
- from sqlalchemy.orm import Session
- from pydantic import BaseModel, validator
- from typing import List, Optional
- from datetime import datetime
- from enum import Enum
- from backend.core.database import get_db
- from backend.core.models import User, ChatHistory, ChatMessage
- from backend.core.auth import get_current_user
# All chat-history endpoints are mounted under /chat-history and grouped as "chat".
router = APIRouter(prefix="/chat-history", tags=["chat"])


class MessageRole(str, Enum):
    """Closed set of roles a chat message may have.

    Because the enum mixes in ``str``, each member compares equal to its plain
    string value, e.g. ``MessageRole.USER == "user"``.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class MessageCreate(BaseModel):
    """Request payload for a single chat message.

    ``role`` is coerced to a MessageRole by ``validate_role``; ``content`` must
    contain at least one non-whitespace character.
    """

    role: MessageRole  # enum instead of a free-form string
    content: str

    @validator('content')
    def content_must_not_be_empty(cls, v):
        """Reject empty or whitespace-only content; the value is returned unstripped."""
        if not v or not v.strip():
            raise ValueError('消息内容不能为空')
        return v

    @validator('role', pre=True)
    def validate_role(cls, v):
        """Coerce the raw ``role`` input to a MessageRole.

        Matching is case-insensitive and ignores surrounding whitespace, so
        " Assistant " maps to MessageRole.ASSISTANT instead of silently
        falling through to USER.

        Anything that cannot be matched (unknown strings, non-string values)
        still falls back to MessageRole.USER. This keeps the original lenient
        behaviour rather than failing validation.
        """
        if isinstance(v, MessageRole):
            return v

        if isinstance(v, str):
            try:
                # Enum lookup by value; raises ValueError for unknown roles.
                return MessageRole(v.strip().lower())
            except ValueError:
                pass

        # Unrecognised input: deliberate best-effort default.
        return MessageRole.USER


class MessageResponse(BaseModel):
    """Outgoing representation of one stored chat message."""

    id: int
    role: str  # declared as a plain string rather than MessageRole
    content: str
    timestamp: datetime

    class Config:
        # pydantic v1 switch that lets the model be populated from ORM attributes.
        orm_mode = True


class ChatHistoryCreate(BaseModel):
    """Request payload for creating a chat history together with its first messages."""

    title: str
    messages: List[MessageCreate]  # each entry is validated by MessageCreate


class ChatHistoryResponse(BaseModel):
    """Full chat history sent to the client, including its messages."""

    id: int
    title: str
    created_at: datetime
    updated_at: datetime
    messages: List[MessageResponse]  # read from the source object's `messages` attribute

    class Config:
        # pydantic v1: allow building this model from ORM objects.
        orm_mode = True


class ChatHistorySummary(BaseModel):
    """Lightweight list entry for a chat history (no messages).

    ``message_count``, ``preview`` and ``has_code`` are attributes that the
    list endpoint attaches to each ChatHistory object. Without declaring them
    here, FastAPI's ``response_model`` filtering silently drops them. They have
    defaults, so objects that lack these attributes still validate.
    """

    id: int
    title: str
    created_at: datetime
    updated_at: datetime
    message_count: int = 0  # number of messages in the chat
    preview: str = ""  # start of the most recent user message, "" if none
    has_code: bool = False  # True when any message contains a ``` fence

    class Config:
        # pydantic v1: read field values from ORM object attributes.
        orm_mode = True


class ChatHistoryUpdate(BaseModel):
    """Partial-update payload for a chat history; only the title is editable."""

    title: Optional[str] = None  # None when the client does not send a title


@router.post("/", response_model=ChatHistoryResponse)
async def create_chat_history(
    chat_data: ChatHistoryCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a chat history and its initial messages for the current user.

    The chat row and all of its messages are written in one transaction.
    Previously the chat was committed first, so a failure while inserting
    messages left an orphaned, empty chat behind.

    Returns:
        The refreshed ChatHistory ORM object, serialized as ChatHistoryResponse.

    Raises:
        Any exception from the database layer. It is re-raised after the
        session is rolled back and the error is logged.
    """
    import logging

    logger = logging.getLogger(__name__)
    # Log only metadata; the old print() dumped every message body to stdout.
    logger.debug(
        "Creating chat history %r with %d message(s)",
        chat_data.title,
        len(chat_data.messages),
    )

    try:
        chat = ChatHistory(
            user_id=current_user.id,
            title=chat_data.title
        )
        db.add(chat)
        # flush() assigns chat.id without committing, so the messages below can
        # reference it and everything commits (or rolls back) together.
        db.flush()

        for msg_data in chat_data.messages:
            db.add(ChatMessage(
                chat_id=chat.id,
                # Persist the plain string ("user", ...), not the enum member:
                # some DB drivers stringify a str-Enum as "MessageRole.USER".
                role=MessageRole(msg_data.role).value,
                content=msg_data.content
            ))

        db.commit()
        db.refresh(chat)
        return chat
    except Exception:
        # Leave the session usable and avoid persisting a partial chat.
        db.rollback()
        logger.exception("保存聊天历史出错")
        raise


def _message_counts(db: Session, user_id: int) -> dict:
    """Return ``{chat_id: message count}`` for chats owned by *user_id*.

    Chats without messages are absent; callers default them to 0.
    """
    from sqlalchemy import func

    rows = (
        db.query(ChatMessage.chat_id, func.count())
        .join(ChatHistory, ChatHistory.id == ChatMessage.chat_id)
        .filter(ChatHistory.user_id == user_id)
        .group_by(ChatMessage.chat_id)
        .all()
    )
    return {chat_id: count for chat_id, count in rows}


def _chat_ids_with_code(db: Session, user_id: int) -> set:
    """Return the ids of *user_id*'s chats that have a message containing ```."""
    rows = (
        db.query(ChatMessage.chat_id)
        .join(ChatHistory, ChatHistory.id == ChatMessage.chat_id)
        .filter(
            ChatHistory.user_id == user_id,
            ChatMessage.content.like("%```%")
        )
        .distinct()
        .all()
    )
    return {chat_id for (chat_id,) in rows}


def _latest_user_message_contents(db: Session, user_id: int) -> dict:
    """Return ``{chat_id: content}`` of the newest "user" message per chat.

    "Newest" means the maximum ``timestamp`` in that chat. When several user
    messages share that timestamp, an arbitrary one is kept. The original
    ORDER BY ... LIMIT 1 query was equally unordered in that case.
    """
    from sqlalchemy import and_, func

    latest = (
        db.query(
            ChatMessage.chat_id.label("chat_id"),
            func.max(ChatMessage.timestamp).label("ts")
        )
        .join(ChatHistory, ChatHistory.id == ChatMessage.chat_id)
        .filter(ChatHistory.user_id == user_id, ChatMessage.role == "user")
        .group_by(ChatMessage.chat_id)
        .subquery()
    )
    rows = (
        db.query(ChatMessage.chat_id, ChatMessage.content)
        .join(
            latest,
            and_(
                ChatMessage.chat_id == latest.c.chat_id,
                ChatMessage.timestamp == latest.c.ts
            )
        )
        .filter(ChatMessage.role == "user")
        .all()
    )
    contents = {}
    for chat_id, content in rows:
        contents.setdefault(chat_id, content)
    return contents


def _make_preview(text, limit: int = 100) -> str:
    """Return the first *limit* characters of *text*, plus "..." if it was truncated."""
    text = text or ""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


@router.get("/", response_model=List[ChatHistorySummary])
async def get_chat_histories(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the current user's chat histories, most recently updated first.

    Each ChatHistory object is annotated with:
    - ``message_count``: number of messages;
    - ``preview``: up to 100 characters of the latest user message, "" if none;
    - ``has_code``: whether any message contains a ``` fence.

    These values are computed with three aggregate queries in total. The
    previous version ran three queries per chat (N+1). They are only included
    in the response if ChatHistorySummary declares these fields.
    """
    chats = (
        db.query(ChatHistory)
        .filter(ChatHistory.user_id == current_user.id)
        .order_by(ChatHistory.updated_at.desc())
        .all()
    )
    if not chats:
        return chats

    counts = _message_counts(db, current_user.id)
    code_chat_ids = _chat_ids_with_code(db, current_user.id)
    latest_user_contents = _latest_user_message_contents(db, current_user.id)

    for chat in chats:
        chat.message_count = counts.get(chat.id, 0)
        chat.has_code = chat.id in code_chat_ids
        chat.preview = _make_preview(latest_user_contents.get(chat.id))

    return chats


@router.get("/{chat_id}", response_model=ChatHistoryResponse)
async def get_chat_history(
    chat_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return a single chat history (with messages) owned by the current user.

    Raises:
        HTTPException: 404 if no chat with this id belongs to the current user.
    """
    owned_chat = (
        db.query(ChatHistory)
        .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)
        .first()
    )
    if not owned_chat:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="聊天历史不存在或无权访问"
        )
    return owned_chat


@router.post("/{chat_id}/messages", response_model=MessageResponse)
async def add_message(
    chat_id: int,
    message: MessageCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Append a message to an existing chat owned by the current user.

    Also bumps the chat's ``updated_at`` in the same commit.

    Returns:
        The refreshed ChatMessage ORM object, serialized as MessageResponse.

    Raises:
        HTTPException: 404 if the chat does not exist or belongs to another user.
    """
    # Ownership check: the chat must exist AND belong to the caller.
    chat = db.query(ChatHistory)\
        .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)\
        .first()

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="聊天历史不存在或无权访问"
        )

    db_message = ChatMessage(
        chat_id=chat_id,
        # Store the plain string value ("user", ...). Some DB drivers stringify
        # a str-Enum member as "MessageRole.USER", which would then be echoed
        # back to the client and missed by `role == "user"` filters.
        role=MessageRole(message.role).value,
        content=message.content
    )

    chat.updated_at = datetime.utcnow()

    db.add(db_message)
    db.commit()
    db.refresh(db_message)

    return db_message


@router.delete("/{chat_id}")
async def delete_chat_history(
    chat_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the current user's chat histories along with its messages.

    Returns:
        ``{"message": "聊天历史已删除"}`` once the deletion is committed.

    Raises:
        HTTPException: 404 if the chat does not exist or belongs to another user.
    """
    target = (
        db.query(ChatHistory)
        .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)
        .first()
    )
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="聊天历史不存在或无权访问"
        )

    # Bulk-delete the chat's messages first, then the chat row itself.
    messages_of_chat = db.query(ChatMessage).filter(ChatMessage.chat_id == chat_id)
    messages_of_chat.delete()
    db.delete(target)
    db.commit()

    return {"message": "聊天历史已删除"}


@router.put("/{chat_id}", response_model=ChatHistoryResponse)
async def update_chat_history(
    chat_id: int,
    update_data: ChatHistoryUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Rename one of the current user's chats and bump its ``updated_at``.

    A missing or empty title leaves the current title in place; ``updated_at``
    is refreshed either way.

    Raises:
        HTTPException: 404 if the chat does not exist or belongs to another user.
    """
    target = (
        db.query(ChatHistory)
        .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)
        .first()
    )
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="聊天历史不存在或无权访问"
        )

    new_title = update_data.title
    if new_title:
        target.title = new_title

    target.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(target)

    return target
|