123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- """
- 用户资料API路由
- """
- from fastapi import APIRouter, Depends, HTTPException, status
- from sqlalchemy.orm import Session
- from pydantic import BaseModel
- from typing import List, Optional
- from datetime import datetime
- from backend.core.database import get_db
- from backend.core.models import User
- from backend.core.auth import get_current_user
# Router for the user-profile endpoints; every path declared on it is mounted
# under the "/users" prefix and grouped under the "users" tag.
router = APIRouter(prefix="/users", tags=["users"])
class UserProfileUpdate(BaseModel):
    """Request body for a partial profile update.

    Every field defaults to ``None``; the update handler only writes fields
    whose value is not ``None``, so omitted fields keep their stored values.
    """
    # New username; the handler rejects it if another account already uses it.
    username: Optional[str] = None
    # New email; plain ``str``, so no address-format validation happens here.
    # The handler rejects it if another account already uses it.
    email: Optional[str] = None
    institution: Optional[str] = None
    position: Optional[str] = None
    # Replaces the whole stored list when provided (not merged).
    research_fields: Optional[List[str]] = None
-
class UserProfileResponse(BaseModel):
    """Profile payload returned by the ``/users/profile`` endpoints.

    ``institution``, ``position`` and ``research_fields`` are optional here
    because they are optional in the profile-update request, so a user may
    never have set them. Declaring them as required non-null fields made
    response validation fail (HTTP 500) for such users.
    """
    id: int
    username: str
    email: str
    # NOTE(review): presumably nullable on the User model (they are only
    # filled in via optional profile updates) — TODO confirm. Optional so an
    # unset value serializes as null instead of failing validation.
    institution: Optional[str] = None
    position: Optional[str] = None
    research_fields: Optional[list] = None
    membership_type: str
    # Date string formatted as "YYYY-MM-DD"; None when there is no expiry.
    expiry_date: Optional[str] = None
    credits: int

    class Config:
        # Allow building the model from ORM objects (Pydantic v1 option name;
        # Pydantic v2 renames it to ``from_attributes``).
        orm_mode = True
@router.get("/profile", response_model=UserProfileResponse)
async def get_user_profile(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile.

    The expiry date is normalized to a ``YYYY-MM-DD`` string (``None`` when
    unset) so the response always uses the same date format.
    """
    user = current_user

    # Normalize the date representation up front for a consistent payload.
    expiry = user.expiry_date
    formatted_expiry = expiry.strftime("%Y-%m-%d") if expiry else None

    return dict(
        id=user.id,
        username=user.username,
        email=user.email,
        institution=user.institution,
        position=user.position,
        research_fields=user.research_fields,
        membership_type=user.membership_type,
        expiry_date=formatted_expiry,
        credits=user.credits,
    )
@router.put("/profile", response_model=UserProfileResponse)
async def update_user_profile(
    profile: UserProfileUpdate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Apply a partial update to the authenticated user's profile.

    Only fields that are not ``None`` in ``profile`` are written; the others
    keep their stored values. ``updated_at`` is refreshed on every call.

    Args:
        profile: The requested changes.
        current_user: The user resolved by ``get_current_user``.
        db: Database session provided by ``get_db``.

    Returns:
        The updated profile as a dict matching ``UserProfileResponse``, with
        ``expiry_date`` formatted as ``YYYY-MM-DD`` (or ``None``).

    Raises:
        HTTPException: 404 if the user row can no longer be found; 400 if the
            new username or email already belongs to another user.
    """
    # NOTE(review): this handler is ``async`` but issues blocking calls on a
    # synchronous SQLAlchemy Session, which stalls the event loop while they
    # run. A plain ``def`` handler would let FastAPI run it in a worker
    # thread; left as-is to keep the function's interface unchanged.

    # Re-load the user through this request's session so the changes below
    # are tracked and committed by ``db``.
    user = db.query(User).filter(User.id == current_user.id).first()
    if user is None:
        # Previously this fell through to an AttributeError (HTTP 500).
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在",
        )

    if profile.username is not None:
        # Reject a username already owned by a different account; keeping
        # the user's own current username is allowed.
        existing_user = db.query(User).filter(User.username == profile.username).first()
        if existing_user and existing_user.id != user.id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已被使用",
            )
        user.username = profile.username

    if profile.email is not None:
        # Same ownership rule as for the username.
        existing_user = db.query(User).filter(User.email == profile.email).first()
        if existing_user and existing_user.id != user.id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已被使用",
            )
        user.email = profile.email

    if profile.institution is not None:
        user.institution = profile.institution

    if profile.position is not None:
        user.position = profile.position

    if profile.research_fields is not None:
        user.research_fields = profile.research_fields

    # Persist the changes and reload the row so the response reflects the
    # stored state.
    user.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(user)

    # Convert the expiry datetime to a string before returning.
    return {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "institution": user.institution,
        "position": user.position,
        "research_fields": user.research_fields,
        "membership_type": user.membership_type,
        "expiry_date": user.expiry_date.strftime("%Y-%m-%d") if user.expiry_date else None,
        "credits": user.credits,
    }
|