- # mypy: allow-untyped-defs
- r"""
- This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python.
- Metal is Apple's API for programming the Metal GPU (graphics processing unit). Using MPS means that increased
- performance can be achieved by running work on the Metal GPU(s).
- See https://developer.apple.com/documentation/metalperformanceshaders for more details.
- """
- from typing import Union
- import torch
- from .. import Tensor
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
# Cached process-wide default generator; created lazily on first use.
_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]


# local helper function (not public or exported)
def _get_default_mps_generator() -> torch._C.Generator:
    r"""Return the default MPS generator, creating and caching it on first call."""
    global _default_mps_generator
    if _default_mps_generator is not None:
        return _default_mps_generator
    _default_mps_generator = torch._C._mps_get_default_generator()
    return _default_mps_generator
- def device_count() -> int:
- r"""Returns the number of available MPS devices."""
- return int(torch._C._has_mps and torch._C._mps_is_available())
def synchronize() -> None:
    r"""Waits for all kernels in all streams on a MPS device to complete."""
    sync_fn = torch._C._mps_deviceSynchronize
    return sync_fn()
def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
    """
    # NOTE(review): `device` is accepted for API symmetry with other backends
    # but is not consulted here — there is a single default MPS generator.
    generator = _get_default_mps_generator()
    return generator.get_state()
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mps"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
    """
    # Clone defensively (contiguous) so later caller-side mutation of
    # `new_state` cannot corrupt the generator's stored state.
    state_clone = new_state.clone(memory_format=torch.contiguous_format)
    _get_default_mps_generator().set_state(state_clone)
- def manual_seed(seed: int) -> None:
- r"""Sets the seed for generating random numbers.
- Args:
- seed (int): The desired seed.
- """
- # the torch.mps.manual_seed() can be called from the global
- # torch.manual_seed() in torch/random.py. So we need to make
- # sure mps is available (otherwise we just return without
- # erroring out)
- if not torch._C._has_mps:
- return
- seed = int(seed)
- _get_default_mps_generator().manual_seed(seed)
def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number."""
    generator = _get_default_mps_generator()
    generator.seed()
def empty_cache() -> None:
    r"""Releases all unoccupied cached memory currently held by the caching
    allocator so that those can be used in other GPU applications.
    """
    # Delegates straight to the C++ MPS allocator binding.
    torch._C._mps_emptyCache()
def set_per_process_memory_fraction(fraction: float) -> None:
    r"""Set memory fraction for limiting process's memory allocation on MPS device.

    The allowed value equals the fraction multiplied by recommended maximum device memory
    (obtained from Metal API device.recommendedMaxWorkingSetSize).
    If trying to allocate more than the allowed value in a process, it will raise an out of
    memory error in allocator.

    Args:
        fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.

    Raises:
        TypeError: If ``fraction`` is not a ``float`` (ints are deliberately rejected).
        ValueError: If ``fraction`` is outside the ``0~2`` range.

    .. note::
        Passing 0 to fraction means unlimited allocations
        (may cause system failure if out of memory).
        Passing fraction greater than 1.0 allows limits beyond the value
        returned from device.recommendedMaxWorkingSetSize.
    """
    # Validate before touching the backend so errors surface even when no
    # MPS device is present.
    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 2:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")

    torch._C._mps_setMemoryFraction(fraction)
def current_allocated_memory() -> int:
    r"""Returns the current GPU memory occupied by tensors in bytes.

    .. note::
        The returned size does not include cached allocations in
        memory pools of MPSAllocator.
    """
    query = torch._C._mps_currentAllocatedMemory
    return query()
def driver_allocated_memory() -> int:
    r"""Returns total GPU memory allocated by Metal driver for the process in bytes.

    .. note::
        The returned size includes cached allocations in MPSAllocator pools
        as well as allocations from MPS/MPSGraph frameworks.
    """
    query = torch._C._mps_driverAllocatedMemory
    return query()
- from . import profiler
- from .event import Event
# Public API of this module: RNG helpers, memory introspection/limits, and the
# profiler/Event submodule re-exports.
__all__ = [
    "device_count",
    "get_rng_state",
    "manual_seed",
    "seed",
    "set_rng_state",
    "synchronize",
    "empty_cache",
    "set_per_process_memory_fraction",
    "current_allocated_memory",
    "driver_allocated_memory",
    "Event",
    "profiler",
]
|