chat_history.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
"""
Chat history API routes.

Endpoints under ``/chat-history`` for creating, listing, reading, appending
messages to, and deleting the current user's chat histories.
"""
from datetime import datetime, timezone
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session

from backend.core.auth import get_current_user
from backend.core.database import get_db
from backend.core.models import ChatHistory, ChatMessage, User
  12. router = APIRouter(tags=["chat"], prefix="/chat-history")
class MessageCreate(BaseModel):
    # Request payload for a single chat message.
    # `role` is free-form text here (no enum validation) — presumably values
    # like "user"/"assistant"; TODO confirm against the frontend.
    role: str
    content: str
class MessageResponse(BaseModel):
    # Serialized form of one stored chat message returned by the API.
    id: int
    role: str
    content: str
    # Not set by any route in this module; presumably filled in by the
    # ChatMessage model/DB default when the row is inserted — verify.
    timestamp: datetime

    class Config:
        # Allow building this schema straight from ORM objects (ChatMessage rows).
        # NOTE: `orm_mode` is the Pydantic v1 name; Pydantic v2 calls it `from_attributes`.
        orm_mode = True
class ChatHistoryCreate(BaseModel):
    # Request body for creating a chat: its title plus the initial messages
    # to store with it (the list may be empty).
    title: str
    messages: List[MessageCreate]
class ChatHistoryResponse(BaseModel):
    # Full view of a chat, including all of its messages.
    id: int
    title: str
    created_at: datetime
    updated_at: datetime
    # Populated from the ORM object's `messages` attribute when serialized.
    messages: List[MessageResponse]

    class Config:
        # Allow building this schema straight from ORM objects (ChatHistory rows).
        # NOTE: `orm_mode` is the Pydantic v1 name; Pydantic v2 calls it `from_attributes`.
        orm_mode = True
class ChatHistorySummary(BaseModel):
    # Lightweight listing view of a chat: the same header fields as
    # ChatHistoryResponse, but without the messages.
    id: int
    title: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building this schema straight from ORM objects (ChatHistory rows).
        # NOTE: `orm_mode` is the Pydantic v1 name; Pydantic v2 calls it `from_attributes`.
        orm_mode = True
  41. @router.post("/", response_model=ChatHistoryResponse)
  42. async def create_chat_history(
  43. chat_data: ChatHistoryCreate,
  44. current_user: User = Depends(get_current_user),
  45. db: Session = Depends(get_db)
  46. ):
  47. """创建新的聊天历史"""
  48. # 创建聊天历史
  49. chat = ChatHistory(
  50. user_id=current_user.id,
  51. title=chat_data.title
  52. )
  53. db.add(chat)
  54. db.commit()
  55. db.refresh(chat)
  56. # 添加消息
  57. for msg_data in chat_data.messages:
  58. message = ChatMessage(
  59. chat_id=chat.id,
  60. role=msg_data.role,
  61. content=msg_data.content
  62. )
  63. db.add(message)
  64. db.commit()
  65. db.refresh(chat)
  66. return chat
  67. @router.get("/", response_model=List[ChatHistorySummary])
  68. async def get_chat_histories(
  69. current_user: User = Depends(get_current_user),
  70. db: Session = Depends(get_db)
  71. ):
  72. """获取用户的所有聊天历史摘要"""
  73. chats = db.query(ChatHistory)\
  74. .filter(ChatHistory.user_id == current_user.id)\
  75. .order_by(ChatHistory.updated_at.desc())\
  76. .all()
  77. return chats
  78. @router.get("/{chat_id}", response_model=ChatHistoryResponse)
  79. async def get_chat_history(
  80. chat_id: int,
  81. current_user: User = Depends(get_current_user),
  82. db: Session = Depends(get_db)
  83. ):
  84. """获取特定聊天历史的详情"""
  85. chat = db.query(ChatHistory)\
  86. .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)\
  87. .first()
  88. if not chat:
  89. raise HTTPException(
  90. status_code=status.HTTP_404_NOT_FOUND,
  91. detail="聊天历史不存在或无权访问"
  92. )
  93. return chat
  94. @router.post("/{chat_id}/messages", response_model=MessageResponse)
  95. async def add_message(
  96. chat_id: int,
  97. message: MessageCreate,
  98. current_user: User = Depends(get_current_user),
  99. db: Session = Depends(get_db)
  100. ):
  101. """向现有聊天添加新消息"""
  102. # 检查聊天是否存在且属于当前用户
  103. chat = db.query(ChatHistory)\
  104. .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)\
  105. .first()
  106. if not chat:
  107. raise HTTPException(
  108. status_code=status.HTTP_404_NOT_FOUND,
  109. detail="聊天历史不存在或无权访问"
  110. )
  111. # 创建新消息
  112. db_message = ChatMessage(
  113. chat_id=chat_id,
  114. role=message.role,
  115. content=message.content
  116. )
  117. # 更新聊天的更新时间
  118. chat.updated_at = datetime.utcnow()
  119. db.add(db_message)
  120. db.commit()
  121. db.refresh(db_message)
  122. return db_message
  123. @router.delete("/{chat_id}")
  124. async def delete_chat_history(
  125. chat_id: int,
  126. current_user: User = Depends(get_current_user),
  127. db: Session = Depends(get_db)
  128. ):
  129. """删除聊天历史"""
  130. chat = db.query(ChatHistory)\
  131. .filter(ChatHistory.id == chat_id, ChatHistory.user_id == current_user.id)\
  132. .first()
  133. if not chat:
  134. raise HTTPException(
  135. status_code=status.HTTP_404_NOT_FOUND,
  136. detail="聊天历史不存在或无权访问"
  137. )
  138. # 删除聊天相关的所有消息
  139. db.query(ChatMessage).filter(ChatMessage.chat_id == chat_id).delete()
  140. # 删除聊天
  141. db.delete(chat)
  142. db.commit()
  143. return {"message": "聊天历史已删除"}