12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- """
- 用户认证功能
- """
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from backend.config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
from backend.core.database import get_db
from backend.core.models import User
# Password hashing context. bcrypt is the only configured scheme;
# deprecated="auto" tells passlib to treat every scheme other than the
# default as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# OAuth2 password-flow bearer scheme. tokenUrl is the endpoint advertised
# to clients as the place to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: The password as supplied by the client.
        hashed_password: The stored hash to compare against.

    Returns:
        Whatever ``pwd_context.verify`` reports (truthy when the password
        matches the hash).
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password):
    """Hash *password* with the module-level ``pwd_context``.

    Args:
        password: Plaintext password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
def get_user(db: Session, username: str):
    """Look up a user by username.

    Args:
        db: Database session to query.
        username: Exact username to match.

    Returns:
        The first ``User`` row whose ``username`` equals *username*, or
        ``None`` if no row matches.
    """
    matching = db.query(User).filter(User.username == username)
    return matching.first()
def authenticate_user(db: Session, username: str, password: str):
    """Authenticate a user by username and password.

    Args:
        db: Database session used for the user lookup.
        username: Username to look up.
        password: Plaintext password supplied by the client.

    Returns:
        The matching ``User`` on success. Returns ``False`` when the user
        does not exist or the password is wrong. ``False`` rather than
        ``None`` is kept for compatibility with existing callers.
    """
    user = get_user(db, username)
    if not user:
        # Burn roughly the same time as a real hash check, so response
        # timing does not reveal whether the username exists.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The dict is shallow-copied, so
            the caller's dict does not gain an ``exp`` key.
        expires_delta: Token lifetime. When ``None`` (the default),
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes is used. An explicit
            ``timedelta(0)`` is honoured instead of being silently replaced
            by the default.

    Returns:
        The encoded token returned by ``jwt.encode``.
    """
    to_encode = data.copy()

    # Test for None explicitly: timedelta(0) is falsy and must not fall
    # back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    # Use an aware UTC timestamp, because datetime.utcnow() is deprecated
    # since Python 3.12. python-jose converts it via utctimetuple(), so the
    # encoded "exp" value is the same.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode["exp"] = expire
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the user identified by a bearer token.

    Intended as a FastAPI dependency. The token is supplied by
    ``oauth2_scheme`` and the session by ``get_db``.

    Args:
        token: Encoded JWT presented by the client.
        db: Database session used for the user lookup.

    Returns:
        The ``User`` whose username equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            any of these hold:
            - the token fails ``jwt.decode``;
            - the token has no ``sub`` claim;
            - ``sub`` names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try minimal: only decoding can raise JWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Reuse the shared lookup helper instead of duplicating its query.
    user = get_user(db, username)
    if user is None:
        raise credentials_exception

    return user
|