auth.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. """
  2. 用户认证功能
  3. """
  4. from datetime import datetime, timedelta
  5. from typing import Optional
  6. from jose import JWTError, jwt
  7. from passlib.context import CryptContext
  8. from fastapi import Depends, HTTPException, status
  9. from fastapi.security import OAuth2PasswordBearer
  10. from sqlalchemy.orm import Session
  11. from backend.core.database import get_db
  12. from backend.core.models import User
  13. from backend.config import SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
  14. # 密码上下文
  15. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  16. # OAuth2 密码Bearer
  17. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
  18. def verify_password(plain_password, hashed_password):
  19. """验证密码"""
  20. return pwd_context.verify(plain_password, hashed_password)
  21. def get_password_hash(password):
  22. """获取密码哈希"""
  23. return pwd_context.hash(password)
  24. def get_user(db: Session, username: str):
  25. """通过用户名获取用户"""
  26. return db.query(User).filter(User.username == username).first()
  27. def authenticate_user(db: Session, username: str, password: str):
  28. """认证用户"""
  29. user = get_user(db, username)
  30. if not user or not verify_password(password, user.hashed_password):
  31. return False
  32. return user
  33. def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
  34. """创建访问令牌"""
  35. to_encode = data.copy()
  36. if expires_delta:
  37. expire = datetime.utcnow() + expires_delta
  38. else:
  39. expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  40. to_encode.update({"exp": expire})
  41. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  42. return encoded_jwt
  43. def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
  44. """获取当前用户并处理异常情况"""
  45. credentials_exception = HTTPException(
  46. status_code=status.HTTP_401_UNAUTHORIZED,
  47. detail="无效的认证凭据",
  48. headers={"WWW-Authenticate": "Bearer"},
  49. )
  50. try:
  51. # 解析token
  52. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  53. username: str = payload.get("sub")
  54. if username is None:
  55. raise credentials_exception
  56. except JWTError:
  57. raise credentials_exception
  58. # 查找用户
  59. user = db.query(User).filter(User.username == username).first()
  60. if user is None:
  61. raise credentials_exception
  62. return user