123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- """
- 用户认证API路由
- """
- from fastapi import APIRouter, Depends, HTTPException, status, Request
- from fastapi.security import OAuth2PasswordRequestForm
- from sqlalchemy.orm import Session
- from datetime import timedelta, datetime
- from pydantic import BaseModel, Field
- from typing import Optional
- import logging
- from backend.core.database import get_db
- from backend.core.models import User
- from backend.core.auth import (
- authenticate_user,
- create_access_token,
- get_password_hash,
- get_current_user
- )
- from backend.config import ACCESS_TOKEN_EXPIRE_MINUTES
# Module-level logger plus the router that groups every authentication
# endpoint below under the "/auth" path prefix and the "auth" OpenAPI tag.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
class Token(BaseModel):
    # Response body of POST /auth/token (login_for_access_token).
    # Value produced by create_access_token.
    access_token: str
    # This module only ever returns "bearer".
    token_type: str
-
class UserCreate(BaseModel):
    # Request body for POST /auth/register (register_user).
    username: str
    # NOTE(review): plain str -- no email-format validation happens here.
    email: str
    # Plain-text password; register_user hashes it with get_password_hash
    # before storing it.
    password: str
    # Optional profile fields. An explicit null is accepted by the Optional
    # type; register_user then falls back to the same defaults via `or`.
    # Default institution: "Jiangxi University of Finance and Economics".
    institution: Optional[str] = "江西财经大学"
    # Default position: "student".
    position: Optional[str] = "学生"
-
class UserResponse(BaseModel):
    # Public user profile returned by POST /auth/register and GET /auth/me.
    id: int
    username: str
    email: str
    institution: str
    position: str
    research_fields: list
    membership_type: str
    # Date string; register_user formats it with strftime("%Y-%m-%d").
    # NOTE(review): a datetime value (which is what register_user stores on
    # the ORM model) does not validate against `str`, so callers must
    # pre-format it before returning.
    expiry_date: Optional[str] = None
    credits: int

    class Config:
        # Pydantic v1 option that allows populating the model from ORM objects
        # via attribute access.
        # NOTE(review): Pydantic v2 renames this to `from_attributes`.
        orm_mode = True
@router.post("/token", response_model=Token)
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
):
    """Log in: exchange form-encoded username/password for a bearer token.

    Args:
        form_data: OAuth2 password-flow form carrying ``username`` and
            ``password``.
        db: Database session injected by ``get_db``.

    Returns:
        A dict matching ``Token`` whose ``sub`` claim is the username. The
        token lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Raises:
        HTTPException: 401, with a ``WWW-Authenticate: Bearer`` header, when
            ``authenticate_user`` returns a falsy value.
    """
    # NOTE(review): the session and authenticate_user are presumably
    # synchronous, so this `async def` blocks the event loop while they run.
    # Confirm, and consider a plain `def` endpoint.
    authenticated = authenticate_user(db, form_data.username, form_data.password)
    if authenticated:
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        token = create_access_token(
            data={"sub": authenticated.username},
            expires_delta=lifetime,
        )
        return {"access_token": token, "token_type": "bearer"}

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="用户名或密码不正确",
        headers={"WWW-Authenticate": "Bearer"},
    )
@router.post("/register", response_model=UserResponse)
async def register_user(user_create: UserCreate, db: Session = Depends(get_db)):
    """Register a new user account.

    Rejects the request when the username or email is already taken. Otherwise
    it hashes the password, stores the user with this module's default
    membership settings and returns the new profile.

    Args:
        user_create: Registration payload.
        db: SQLAlchemy session injected by ``get_db``.

    Returns:
        A dict matching ``UserResponse``; ``expiry_date`` is formatted as
        ``"%Y-%m-%d"`` (or None).

    Raises:
        HTTPException: 400 if the username or email is already registered;
            500 for any other unexpected error. In the 500 case the error is
            logged and the session is rolled back.
    """
    try:
        # Reject a username that is already in use.
        if db.query(User).filter(User.username == user_create.username).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已被使用,请尝试其他用户名"
            )

        # Reject an email address that is already registered.
        if db.query(User).filter(User.email == user_create.email).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已被注册,请使用其他邮箱"
            )
        # NOTE(review): check-then-insert is racy. Two concurrent requests can
        # both pass the checks above. If the columns carry unique constraints,
        # the loser fails at commit and gets the generic 500 below.

        logger.info("创建新用户: %s", user_create.username)
        hashed_password = get_password_hash(user_create.password)

        # NOTE(review): fixed calendar date, not relative to the registration
        # time -- confirm this is intended.
        expiry_date = datetime(2025, 12, 31)

        db_user = User(
            username=user_create.username,
            email=user_create.email,
            hashed_password=hashed_password,
            institution=user_create.institution or "江西财经大学",
            position=user_create.position or "学生",
            research_fields=["人工智能", "计算机视觉", "自然语言处理"],
            membership_type="高级研究版",
            expiry_date=expiry_date,
            credits=5000
        )

        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        logger.info("用户创建成功: %s", db_user.id)

        # Build the response explicitly: expiry_date is a datetime on the ORM
        # object but a string in UserResponse.
        return {
            "id": db_user.id,
            "username": db_user.username,
            "email": db_user.email,
            "institution": db_user.institution,
            "position": db_user.position,
            "research_fields": db_user.research_fields,
            "membership_type": db_user.membership_type,
            "expiry_date": db_user.expiry_date.strftime("%Y-%m-%d") if db_user.expiry_date else None,
            "credits": db_user.credits
        }
    except HTTPException:
        # Deliberate client errors (raised before any write) pass through.
        raise
    except Exception as e:
        # Discard any partially-applied transaction so that a failed commit
        # does not leave the session unusable for later work.
        db.rollback()
        logger.exception("注册用户时发生错误: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器处理请求时出错,请稍后再试"
        ) from e
@router.get("/me", response_model=UserResponse)
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the currently authenticated user.

    The ORM ``User`` is converted to a plain dict rather than returned
    directly. ``UserResponse.expiry_date`` is declared as a string, but the
    model holds a datetime-like value, which Pydantic's ``str`` validation
    rejects. Returning the ORM object would therefore fail response validation
    for any user with an expiry date. The formatting matches ``register_user``.

    Args:
        current_user: The user resolved by ``get_current_user``.

    Returns:
        A dict matching ``UserResponse``; ``expiry_date`` is ``"%Y-%m-%d"``
        or None.
    """
    expiry = current_user.expiry_date
    return {
        "id": current_user.id,
        "username": current_user.username,
        "email": current_user.email,
        "institution": current_user.institution,
        "position": current_user.position,
        "research_fields": current_user.research_fields,
        "membership_type": current_user.membership_type,
        "expiry_date": expiry.strftime("%Y-%m-%d") if expiry else None,
        "credits": current_user.credits,
    }
|