auth.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """
  2. 用户认证API路由
  3. """
  4. from fastapi import APIRouter, Depends, HTTPException, status, Request
  5. from fastapi.security import OAuth2PasswordRequestForm
  6. from sqlalchemy.orm import Session
  7. from datetime import timedelta, datetime
  8. from pydantic import BaseModel, Field
  9. from typing import Optional
  10. import logging
  11. from backend.core.database import get_db
  12. from backend.core.models import User
  13. from backend.core.auth import (
  14. authenticate_user,
  15. create_access_token,
  16. get_password_hash,
  17. get_current_user
  18. )
  19. from backend.config import ACCESS_TOKEN_EXPIRE_MINUTES
  20. # 设置日志
  21. logger = logging.getLogger(__name__)
  22. router = APIRouter(tags=["auth"], prefix="/auth")
  23. class Token(BaseModel):
  24. access_token: str
  25. token_type: str
  26. class UserCreate(BaseModel):
  27. username: str
  28. email: str
  29. password: str
  30. institution: Optional[str] = "江西财经大学"
  31. position: Optional[str] = "学生"
  32. class UserResponse(BaseModel):
  33. id: int
  34. username: str
  35. email: str
  36. institution: str
  37. position: str
  38. research_fields: list
  39. membership_type: str
  40. expiry_date: Optional[str] = None
  41. credits: int
  42. class Config:
  43. orm_mode = True
  44. @router.post("/token", response_model=Token)
  45. async def login_for_access_token(
  46. form_data: OAuth2PasswordRequestForm = Depends(),
  47. db: Session = Depends(get_db)
  48. ):
  49. """获取访问令牌(登录)"""
  50. user = authenticate_user(db, form_data.username, form_data.password)
  51. if not user:
  52. raise HTTPException(
  53. status_code=status.HTTP_401_UNAUTHORIZED,
  54. detail="用户名或密码不正确",
  55. headers={"WWW-Authenticate": "Bearer"},
  56. )
  57. access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  58. access_token = create_access_token(
  59. data={"sub": user.username},
  60. expires_delta=access_token_expires
  61. )
  62. return {"access_token": access_token, "token_type": "bearer"}
  63. @router.post("/register", response_model=UserResponse)
  64. async def register_user(user_create: UserCreate, db: Session = Depends(get_db)):
  65. """注册新用户"""
  66. try:
  67. # 检查用户名是否已存在
  68. db_user = db.query(User).filter(User.username == user_create.username).first()
  69. if db_user:
  70. # 更友好的错误消息
  71. raise HTTPException(
  72. status_code=status.HTTP_400_BAD_REQUEST,
  73. detail="用户名已被使用,请尝试其他用户名"
  74. )
  75. # 检查邮箱是否已存在
  76. db_user = db.query(User).filter(User.email == user_create.email).first()
  77. if db_user:
  78. raise HTTPException(
  79. status_code=status.HTTP_400_BAD_REQUEST,
  80. detail="邮箱已被注册,请使用其他邮箱"
  81. )
  82. # 创建新用户
  83. logger.info(f"创建新用户: {user_create.username}")
  84. hashed_password = get_password_hash(user_create.password)
  85. # 使用datetime对象
  86. expiry_date = datetime(2025, 12, 31)
  87. db_user = User(
  88. username=user_create.username,
  89. email=user_create.email,
  90. hashed_password=hashed_password,
  91. institution=user_create.institution or "江西财经大学",
  92. position=user_create.position or "学生",
  93. research_fields=["人工智能", "计算机视觉", "自然语言处理"],
  94. membership_type="高级研究版",
  95. expiry_date=expiry_date,
  96. credits=5000
  97. )
  98. db.add(db_user)
  99. db.commit()
  100. db.refresh(db_user)
  101. logger.info(f"用户创建成功: {db_user.id}")
  102. # 明确构建返回数据
  103. result = {
  104. "id": db_user.id,
  105. "username": db_user.username,
  106. "email": db_user.email,
  107. "institution": db_user.institution,
  108. "position": db_user.position,
  109. "research_fields": db_user.research_fields,
  110. "membership_type": db_user.membership_type,
  111. "expiry_date": db_user.expiry_date.strftime("%Y-%m-%d") if db_user.expiry_date else None,
  112. "credits": db_user.credits
  113. }
  114. return result
  115. except HTTPException:
  116. # 直接重新抛出HTTP异常
  117. raise
  118. except Exception as e:
  119. # 记录错误但返回友好消息
  120. logger.error(f"注册用户时发生错误: {str(e)}", exc_info=True)
  121. raise HTTPException(
  122. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  123. detail="服务器处理请求时出错,请稍后再试"
  124. )
  125. @router.get("/me", response_model=UserResponse)
  126. async def read_users_me(current_user: User = Depends(get_current_user)):
  127. """获取当前用户信息"""
  128. return current_user