_foreach_utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from typing import List, Dict, Tuple, Optional
  2. import torch
  3. from torch import Tensor
  4. from torch.autograd.grad_mode import no_grad
  5. from typing_extensions import TypeAlias
  6. def _get_foreach_kernels_supported_devices() -> List[str]:
  7. r"""Return the device type list that supports foreach kernels."""
  8. return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
  9. def _get_fused_kernels_supported_devices() -> List[str]:
  10. r"""Return the device type list that supports fused kernels in optimizer."""
  11. return ["cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]
  12. TensorListList: TypeAlias = List[List[Optional[Tensor]]]
  13. Indices: TypeAlias = List[int]
  14. _foreach_supported_types = [torch.Tensor]
  15. # This util function splits tensors into groups by device and dtype, which is useful before sending
  16. # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
  17. # If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
  18. # - tensorlists CAN be None
  19. # - all tensors in the first specified list cannot be None
  20. # - given an index i, all specified tensorlist[i]s match in dtype and device
  21. # with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
  22. # It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
  23. # Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
  24. # original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
  25. # may be necessary. Check out torch/optim/sgd.py for an example.
  26. @no_grad()
  27. def _group_tensors_by_device_and_dtype(
  28. tensorlistlist: TensorListList,
  29. with_indices: bool = False,
  30. ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
  31. return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
  32. def _device_has_foreach_support(device: torch.device) -> bool:
  33. return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
  34. def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
  35. return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)