users.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """
  2. 用户资料API路由
  3. """
  4. from fastapi import APIRouter, Depends, HTTPException, status
  5. from sqlalchemy.orm import Session
  6. from pydantic import BaseModel
  7. from typing import List, Optional
  8. from datetime import datetime
  9. from backend.core.database import get_db
  10. from backend.core.models import User
  11. from backend.core.auth import get_current_user
  12. router = APIRouter(tags=["users"], prefix="/users")
# Request body for PUT /users/profile.
# Every field defaults to None; the update handler only writes fields whose
# value is not None, so a client may send any subset of them.
class UserProfileUpdate(BaseModel):
    # New login name; the handler rejects it if another user already has it.
    username: Optional[str] = None
    # New e-mail address; the handler rejects it if another user already has it.
    # NOTE(review): declared as plain str, so no e-mail format validation
    # happens at this layer.
    email: Optional[str] = None
    institution: Optional[str] = None
    position: Optional[str] = None
    research_fields: Optional[List[str]] = None
# Response body shared by GET and PUT /users/profile.
class UserProfileResponse(BaseModel):
    id: int
    username: str
    email: str
    # NOTE(review): institution, position and research_fields are required
    # here but optional in UserProfileUpdate. If the corresponding User
    # columns can be NULL, response validation would fail for such users --
    # confirm against the User model.
    institution: str
    position: str
    # NOTE(review): untyped list, whereas UserProfileUpdate uses List[str].
    research_fields: list
    membership_type: str
    # The handlers pass a "%Y-%m-%d" formatted string, or None when the user
    # has no expiry date.
    expiry_date: Optional[str] = None
    credits: int

    class Config:
        # Allows building the model from attribute-bearing objects. The
        # handlers in this file return plain dicts, so this is presumably
        # kept for other callers -- TODO confirm.
        orm_mode = True
  31. @router.get("/profile", response_model=UserProfileResponse)
  32. async def get_user_profile(current_user: User = Depends(get_current_user)):
  33. """获取用户资料"""
  34. # 规范化日期格式,确保一致性
  35. result = {
  36. "id": current_user.id,
  37. "username": current_user.username,
  38. "email": current_user.email,
  39. "institution": current_user.institution,
  40. "position": current_user.position,
  41. "research_fields": current_user.research_fields,
  42. "membership_type": current_user.membership_type,
  43. "expiry_date": current_user.expiry_date.strftime("%Y-%m-%d") if current_user.expiry_date else None,
  44. "credits": current_user.credits
  45. }
  46. return result
  47. @router.put("/profile", response_model=UserProfileResponse)
  48. async def update_user_profile(
  49. profile: UserProfileUpdate,
  50. current_user: User = Depends(get_current_user),
  51. db: Session = Depends(get_db)
  52. ):
  53. """更新用户资料"""
  54. # 获取用户
  55. user = db.query(User).filter(User.id == current_user.id).first()
  56. # 更新用户资料
  57. if profile.username is not None:
  58. # 检查用户名是否已存在
  59. existing_user = db.query(User).filter(User.username == profile.username).first()
  60. if existing_user and existing_user.id != current_user.id:
  61. raise HTTPException(
  62. status_code=status.HTTP_400_BAD_REQUEST,
  63. detail="用户名已被使用"
  64. )
  65. user.username = profile.username
  66. if profile.email is not None:
  67. # 检查邮箱是否已存在
  68. existing_user = db.query(User).filter(User.email == profile.email).first()
  69. if existing_user and existing_user.id != current_user.id:
  70. raise HTTPException(
  71. status_code=status.HTTP_400_BAD_REQUEST,
  72. detail="邮箱已被使用"
  73. )
  74. user.email = profile.email
  75. if profile.institution is not None:
  76. user.institution = profile.institution
  77. if profile.position is not None:
  78. user.position = profile.position
  79. if profile.research_fields is not None:
  80. user.research_fields = profile.research_fields
  81. # 保存更新
  82. user.updated_at = datetime.utcnow()
  83. db.commit()
  84. db.refresh(user)
  85. # 在返回前将datetime转换为字符串
  86. return {
  87. "id": user.id,
  88. "username": user.username,
  89. "email": user.email,
  90. "institution": user.institution,
  91. "position": user.position,
  92. "research_fields": user.research_fields,
  93. "membership_type": user.membership_type,
  94. "expiry_date": user.expiry_date.strftime("%Y-%m-%d") if user.expiry_date else None,
  95. "credits": user.credits
  96. }