__init__.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # mypy: allow-untyped-defs
  2. from functools import lru_cache as _lru_cache
  3. from typing import Optional
  4. import torch
  5. from ...library import Library as _Library
  6. __all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
  7. def is_built() -> bool:
  8. r"""Return whether PyTorch is built with MPS support.
  9. Note that this doesn't necessarily mean MPS is available; just that
  10. if this PyTorch binary were run a machine with working MPS drivers
  11. and devices, we would be able to use it.
  12. """
  13. return torch._C._has_mps
  14. @_lru_cache
  15. def is_available() -> bool:
  16. r"""Return a bool indicating if MPS is currently available."""
  17. return torch._C._mps_is_available()
  18. @_lru_cache
  19. def is_macos_or_newer(major: int, minor: int) -> bool:
  20. r"""Return a bool indicating whether MPS is running on given MacOS or newer."""
  21. return torch._C._mps_is_on_macos_or_newer(major, minor)
  22. @_lru_cache
  23. def is_macos13_or_newer(minor: int = 0) -> bool:
  24. r"""Return a bool indicating whether MPS is running on MacOS 13 or newer."""
  25. return torch._C._mps_is_on_macos_or_newer(13, minor)
  26. _lib: Optional[_Library] = None
  27. def _init():
  28. r"""Register prims as implementation of var_mean and group_norm."""
  29. global _lib
  30. if is_built() is False or _lib is not None:
  31. return
  32. from ..._decomp.decompositions import (
  33. native_group_norm_backward as _native_group_norm_backward,
  34. )
  35. from ..._refs import native_group_norm as _native_group_norm
  36. _lib = _Library("aten", "IMPL")
  37. _lib.impl("native_group_norm", _native_group_norm, "MPS")
  38. _lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")