# mypy: ignore-errors

r"""This file is allowed to initialize CUDA context when imported."""

import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os

CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()

TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == matching_arch


GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))


def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False


def evaluate_platform_supports_efficient_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return True
    return False


PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
# TODO(eqy): gate this against a cuDNN version
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and
                                                  torch.backends.cuda.cudnn_sdp_enabled())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
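
# Illustrative sketch (not taken from a specific test file; the test name is
# hypothetical): the capability and platform flags above are typically consumed
# by skip decorators in test files, e.g.
#
#     @unittest.skipIf(not SM80OrLater, "requires sm80 or newer")
#     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention unsupported")
#     def test_flash_sdpa(self, device):
#         ...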
if TEST_NUMBA:
    try:
        import numba.cuda
        TEST_NUMBA_CUDA = numba.cuda.is_available()
    except Exception as e:
        TEST_NUMBA_CUDA = False
        TEST_NUMBA = False
else:
    TEST_NUMBA_CUDA = False

# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device=f"cuda:{i}")
        __cuda_ctx_rng_initialized = True
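
# Typical use (a sketch, not tied to a particular test): call this once, e.g. in
# a test's setUp, before memory/RNG tests that assume the CUDA context already
# exists on every visible GPU:
#
#     initialize_cuda_context_rng()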

# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True

@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision
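
# Usage sketch (assumptions: `self` is a test instance with a `precision`
# attribute, as used by tf32_on_and_off below; `a` and `b` are placeholders):
#
#     with tf32_off():
#         baseline = a @ b          # guaranteed full-precision FP32 matmul
#     with tf32_on(self, 0.005):
#         result = a @ b            # may run in TF32; self.precision is relaxed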

# This is a decorator that wraps a test so it runs twice: once with
# allow_tf32=True and once with allow_tf32=False. When running with
# allow_tf32=True, it uses the reduced precision specified by the
# argument. For example:
#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#     @tf32_on_and_off(0.005)
#     def test_matmul(self, device, dtype):
#         a = ...; b = ...;
#         c = torch.matmul(a, b)
#         self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA,
# on a CUDA >= 11 build and an >= Ampere GPU, the matmul runs once with TF32
# mode on and once with it off; when TF32 is on, assertEqual uses the reduced
# precision to compare values.
#
# This decorator can be used for functions with or without device/dtype, such as
#     @tf32_on_and_off(0.005)
#     def test_my_op(self)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device, dtype)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an
# Ampere (or newer) device; if device is specified, it checks whether the device
# is CUDA; if dtype is specified, it checks whether the dtype is float32 or
# complex64. TF32 and FP32 differ only when all three checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper

# This is a decorator that wraps a test so it runs with TF32 turned off.
# It is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
# Disabling TF32 forces torch.float tensors to always be computed
# at full precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
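
# Usage sketch (the test name is hypothetical): apply the decorator to a test
# whose numerics would otherwise be perturbed by TF32 matmuls or convolutions.
#
#     @with_tf32_off
#     def test_conv_grad_accumulation(self, device):
#         ...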

def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))


def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))


def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))
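
# These helpers return version tuples, so tests can gate on toolkit/library
# versions with ordinary tuple comparison. A sketch (the version bounds below
# are arbitrary examples, not real requirements):
#
#     if _get_torch_cuda_version() < (11, 6):
#         raise unittest.SkipTest("example: needs CUDA 11.6+")
#     if _get_magma_version() == (0, 0):
#         raise unittest.SkipTest("example: MAGMA not built in")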

def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM


def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    # Create a module+optimizer that will use scaling, and a control module+optimizer
    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    with torch.no_grad():
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            s.copy_(c)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)
    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)

    return mod_control, mod_scaling, opt_control, opt_scaling


# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]

    loss_fn = torch.nn.MSELoss().to(device)

    skip_iter = 2

    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)
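
# How callers typically consume this helper (a rough sketch based on its return
# values; the exact loops in the test files may differ):
#
#     mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#         _create_scaling_case()
#     scaler = torch.cuda.amp.GradScaler()
#     for i, (input, target) in enumerate(data):
#         opt_scaling.zero_grad()
#         with torch.autocast("cuda"):
#             loss = loss_fn(mod_scaling(input), target)
#         scaler.scale(loss).backward()
#         scaler.step(opt_scaling)
#         scaler.update()
#
# In the tests that use this, `skip_iter` typically marks the iteration where an
# inf gradient is injected so the scaler's step-skipping behavior can be checked
# against the unscaled control model.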

# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()