_device.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. from torch.overrides import TorchFunctionMode
  5. from torch.utils._contextlib import context_decorator
  6. import functools
# Device of the most recently entered DeviceContext, or None if no
# DeviceContext has been entered. DeviceContext.__enter__ saves the previous
# value and sets this to its own device; DeviceContext.__exit__ restores the
# saved value.
CURRENT_DEVICE: Optional[torch.device] = None
  8. @functools.lru_cache(1)
  9. def _device_constructors():
  10. return {
  11. # standard ones
  12. torch.empty,
  13. torch.empty_permuted,
  14. torch.empty_strided,
  15. torch.empty_quantized,
  16. torch.ones,
  17. torch.arange,
  18. torch.bartlett_window,
  19. torch.blackman_window,
  20. torch.eye,
  21. torch.fft.fftfreq,
  22. torch.fft.rfftfreq,
  23. torch.full,
  24. torch.fill,
  25. torch.hamming_window,
  26. torch.hann_window,
  27. torch.kaiser_window,
  28. torch.linspace,
  29. torch.logspace,
  30. torch.nested.nested_tensor,
  31. # This function doesn't actually take a device argument
  32. # torch.normal,
  33. torch.ones,
  34. torch.rand,
  35. torch.randn,
  36. torch.randint,
  37. torch.randperm,
  38. torch.range,
  39. torch.sparse_coo_tensor,
  40. torch.sparse_compressed_tensor,
  41. torch.sparse_csr_tensor,
  42. torch.sparse_csc_tensor,
  43. torch.sparse_bsr_tensor,
  44. torch.sparse_bsc_tensor,
  45. torch.tril_indices,
  46. torch.triu_indices,
  47. torch.vander,
  48. torch.zeros,
  49. torch.asarray,
  50. # weird ones
  51. torch.tensor,
  52. torch.as_tensor,
  53. torch.scalar_tensor,
  54. torch.asarray,
  55. }
  56. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  57. class DeviceContext(TorchFunctionMode):
  58. def __init__(self, device):
  59. self.device = torch.device(device)
  60. def __enter__(self):
  61. global CURRENT_DEVICE
  62. self.old_device = CURRENT_DEVICE
  63. CURRENT_DEVICE = self.device
  64. return super().__enter__()
  65. def __exit__(self, exc_type, exc_val, exc_tb):
  66. global CURRENT_DEVICE
  67. CURRENT_DEVICE = self.old_device
  68. return super().__exit__(exc_type, exc_val, exc_tb)
  69. def __torch_function__(self, func, types, args=(), kwargs=None):
  70. kwargs = kwargs or {}
  71. if func in _device_constructors() and kwargs.get('device') is None:
  72. kwargs['device'] = self.device
  73. return func(*args, **kwargs)
  74. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  75. def device_decorator(device, func):
  76. return context_decorator(lambda: device, func)
  77. def set_device(device):
  78. """
  79. Set the default device inside of the wrapped function by decorating it with this function.
  80. If you would like to use this as a context manager, use device as a
  81. context manager directly, e.g., ``with torch.device(device)``.
  82. """
  83. return lambda func: device_decorator(torch.device(device), func)