""" 用户认证API路由 """ from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.orm import Session from datetime import timedelta, datetime from pydantic import BaseModel, Field from typing import Optional import logging from backend.core.database import get_db from backend.core.models import User from backend.core.auth import ( authenticate_user, create_access_token, get_password_hash, get_current_user ) from backend.config import ACCESS_TOKEN_EXPIRE_MINUTES # 设置日志 logger = logging.getLogger(__name__) router = APIRouter(tags=["auth"], prefix="/auth") class Token(BaseModel): access_token: str token_type: str class UserCreate(BaseModel): username: str email: str password: str institution: Optional[str] = "江西财经大学" position: Optional[str] = "学生" class UserResponse(BaseModel): id: int username: str email: str institution: str position: str research_fields: list membership_type: str expiry_date: Optional[str] = None credits: int class Config: orm_mode = True @router.post("/token", response_model=Token) async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): """获取访问令牌(登录)""" user = authenticate_user(db, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码不正确", headers={"WWW-Authenticate": "Bearer"}, ) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} @router.post("/register", response_model=UserResponse) async def register_user(user_create: UserCreate, db: Session = Depends(get_db)): """注册新用户""" try: # 检查用户名是否已存在 db_user = db.query(User).filter(User.username == user_create.username).first() if db_user: # 更友好的错误消息 raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被使用,请尝试其他用户名" ) # 检查邮箱是否已存在 db_user = db.query(User).filter(User.email == user_create.email).first() if db_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册,请使用其他邮箱" ) # 创建新用户 logger.info(f"创建新用户: {user_create.username}") hashed_password = get_password_hash(user_create.password) # 使用datetime对象 expiry_date = datetime(2025, 12, 31) db_user = User( username=user_create.username, email=user_create.email, hashed_password=hashed_password, institution=user_create.institution or "江西财经大学", position=user_create.position or "学生", research_fields=["人工智能", "计算机视觉", "自然语言处理"], membership_type="高级研究版", expiry_date=expiry_date, credits=5000 ) db.add(db_user) db.commit() db.refresh(db_user) logger.info(f"用户创建成功: {db_user.id}") # 明确构建返回数据 result = { "id": db_user.id, "username": db_user.username, "email": db_user.email, "institution": db_user.institution, "position": db_user.position, "research_fields": db_user.research_fields, "membership_type": db_user.membership_type, "expiry_date": db_user.expiry_date.strftime("%Y-%m-%d") if db_user.expiry_date else None, "credits": db_user.credits } return result except HTTPException: # 直接重新抛出HTTP异常 raise except Exception as e: # 记录错误但返回友好消息 logger.error(f"注册用户时发生错误: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="服务器处理请求时出错,请稍后再试" ) @router.get("/me", response_model=UserResponse) async def read_users_me(current_user: User = Depends(get_current_user)): """获取当前用户信息""" return current_user