__init__.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # mypy: allow-untyped-defs
  2. r"""
  3. This package implements abstractions found in ``torch.cuda``
  4. to facilitate writing device-agnostic code.
  5. """
  6. from contextlib import AbstractContextManager
  7. from typing import Any, Optional, Union
  8. import torch
  9. from .. import device as _device
  10. from . import amp
  11. __all__ = [
  12. "is_available",
  13. "synchronize",
  14. "current_device",
  15. "current_stream",
  16. "stream",
  17. "set_device",
  18. "device_count",
  19. "Stream",
  20. "StreamContext",
  21. "Event",
  22. ]
  23. _device_t = Union[_device, str, int, None]
  24. def _is_cpu_support_avx2() -> bool:
  25. r"""Returns a bool indicating if CPU supports AVX2."""
  26. return torch._C._cpu._is_cpu_support_avx2()
  27. def _is_cpu_support_avx512() -> bool:
  28. r"""Returns a bool indicating if CPU supports AVX512."""
  29. return torch._C._cpu._is_cpu_support_avx512()
  30. def _is_cpu_support_vnni() -> bool:
  31. r"""Returns a bool indicating if CPU supports VNNI."""
  32. return torch._C._cpu._is_cpu_support_vnni()
  33. def is_available() -> bool:
  34. r"""Returns a bool indicating if CPU is currently available.
  35. N.B. This function only exists to facilitate device-agnostic code
  36. """
  37. return True
  38. def synchronize(device: _device_t = None) -> None:
  39. r"""Waits for all kernels in all streams on the CPU device to complete.
  40. Args:
  41. device (torch.device or int, optional): ignored, there's only one CPU device.
  42. N.B. This function only exists to facilitate device-agnostic code.
  43. """
  44. class Stream:
  45. """
  46. N.B. This class only exists to facilitate device-agnostic code
  47. """
  48. def __init__(self, priority: int = -1) -> None:
  49. pass
  50. def wait_stream(self, stream) -> None:
  51. pass
  52. class Event:
  53. def query(self) -> bool:
  54. return True
  55. def record(self, stream=None) -> None:
  56. pass
  57. def synchronize(self) -> None:
  58. pass
  59. def wait(self, stream=None) -> None:
  60. pass
  61. _default_cpu_stream = Stream()
  62. _current_stream = _default_cpu_stream
  63. def current_stream(device: _device_t = None) -> Stream:
  64. r"""Returns the currently selected :class:`Stream` for a given device.
  65. Args:
  66. device (torch.device or int, optional): Ignored.
  67. N.B. This function only exists to facilitate device-agnostic code
  68. """
  69. return _current_stream
  70. class StreamContext(AbstractContextManager):
  71. r"""Context-manager that selects a given stream.
  72. N.B. This class only exists to facilitate device-agnostic code
  73. """
  74. cur_stream: Optional[Stream]
  75. def __init__(self, stream):
  76. self.stream = stream
  77. self.prev_stream = _default_cpu_stream
  78. def __enter__(self):
  79. cur_stream = self.stream
  80. if cur_stream is None:
  81. return
  82. global _current_stream
  83. self.prev_stream = _current_stream
  84. _current_stream = cur_stream
  85. def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
  86. cur_stream = self.stream
  87. if cur_stream is None:
  88. return
  89. global _current_stream
  90. _current_stream = self.prev_stream
  91. def stream(stream: Stream) -> AbstractContextManager:
  92. r"""Wrapper around the Context-manager StreamContext that
  93. selects a given stream.
  94. N.B. This function only exists to facilitate device-agnostic code
  95. """
  96. return StreamContext(stream)
  97. def device_count() -> int:
  98. r"""Returns number of CPU devices (not cores). Always 1.
  99. N.B. This function only exists to facilitate device-agnostic code
  100. """
  101. return 1
  102. def set_device(device: _device_t) -> None:
  103. r"""Sets the current device, in CPU we do nothing.
  104. N.B. This function only exists to facilitate device-agnostic code
  105. """
  106. def current_device() -> str:
  107. r"""Returns current device for cpu. Always 'cpu'.
  108. N.B. This function only exists to facilitate device-agnostic code
  109. """
  110. return "cpu"