Context.h

#pragma once

#include <ATen/BlasBackend.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/DeviceAccelerator.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/IPUHooksInterface.h>
#include <ATen/detail/MAIAHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <mutex>
namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    c10::DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
    } else if (device_type == at::kIPU) {
      return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
          device.index());
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
      std::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
    c10::DeviceType device_type = opt_device_type.has_value()
        ? opt_device_type.value()
        : at::getAccelerator(true).value();
    if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks();
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks();
    } else if (device_type == at::kPrivateUse1) {
      return at::detail::getPrivateUse1Hooks();
    } else if (device_type == at::kMTIA) {
      return at::detail::getMTIAHooks();
    } else {
      AT_ERROR(
          c10::DeviceTypeName(device_type), " device type not an accelerator.");
    }
  }
  Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    initXPUIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return c10::DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data);
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }

  static bool isPinnedPtr(const void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasMTIA() {
    return detail::getMTIAHooks().hasMTIA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasCuBLASLt() {
    return detail::getCUDAHooks().hasCuBLASLt();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
  }
  static bool hasXPU() {
    return detail::getXPUHooks().hasXPU();
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  }
  static bool hasMAIA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  }
  // Defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  void lazyInitXPU() {
    c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
  }
  void lazyInitMTIA() {
    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
  }
  void lazyInitPrivateUse1() {
    c10::call_once(thp_init, [&] {
      if (isPrivateUse1HooksRegistered()) {
        at::GetPrivateUse1HooksInterface()->initPrivateUse1();
      }
    });
  }

  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);
  // NB: This method reflects only whether the user has *requested* that
  // CuDNN be enabled; it says nothing about whether CuDNN is actually
  // usable. Use cudnn_is_acceptable to test that instead.
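  // A minimal sketch of that distinction (illustrative only; `t` is a
  // hypothetical input tensor, and cudnn_is_acceptable is the check
  // referenced above):
  //
  //   bool requested = at::globalContext().userEnabledCuDNN(); // user preference
  //   bool usable = at::native::cudnn_is_acceptable(t);        // actual usability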
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);
  bool userEnabledNNPACK() const;
  void setUserEnabledNNPACK(bool e);
  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false).
  // This is useful for debugging. For example, to compare the performance
  // of the flash SDP kernels with the unfused (math) kernel, you can
  // disable the flash SDP kernels; conversely, disabling the math SDP
  // kernel via at::globalContext().setSDPUseMath(false) forces your code
  // onto the fused kernels.
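  // For instance, a hedged sketch of benchmarking only the unfused math
  // path (assumes the setters declared just below):
  //
  //   at::globalContext().setSDPUseFlash(false);
  //   at::globalContext().setSDPUseMemEfficient(false);
  //   at::globalContext().setSDPUseMath(true);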
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;
  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;
  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;
  void setSDPUseCuDNN(bool);
  bool userEnabledCuDNNSDP() const;

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  at::BlasBackend blasPreferredBackend();
  void setBlasPreferredBackend(at::BlasBackend);
  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //   void example_func() {
  //     // See Note [Enabling Deterministic Operations]
  //     if (at::globalContext().deterministicAlgorithms()) {
  //       example_func_deterministic();
  //     } else {
  //       example_func_nondeterministic();
  //     }
  //   }

  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool, bool);
  bool deterministicFillUninitializedMemory() const;
  void setDeterministicFillUninitializedMemory(bool);
  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any new
  //   tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and error-throwing
  // code for a nondeterministic operation:
  //
  //   void example_func() {
  //     // See Note [Writing Nondeterministic Operations]
  //     // Nondeterministic because <reason>
  //     at::globalContext().alertNotDeterministic("example_func");
  //     ...
  //   }
  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
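  // For example, one of the accepted deterministic settings can be exported
  // before launching the process (sketch assuming a POSIX shell; see the
  // NVIDIA documentation above for details):
  //
  //   export CUBLAS_WORKSPACE_CONFIG=:4096:8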
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);

  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();

  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;

  // This method is used to release the original weight after pre-packing.
  // It should be called once before loading/running the model.
  // NB: By default it is set to true for mobile builds.
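  // For example (hedged sketch, for a caller that wants the mobile default
  // on a non-mobile build as well):
  //
  //   at::globalContext().setReleaseWeightsWhenPrepacking(true);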
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();
  bool allowFP16ReductionCPU() const;
  void setAllowFP16ReductionCPU(bool);

 private:
  void initCUDAIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  void initXPUIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::XPU) {
      lazyInitXPU();
    }
  }
  static bool checkCuBLASConfigDeterministic();

  c10::once_flag thc_init;
  c10::once_flag thh_init;
  c10::once_flag thx_init;
  c10::once_flag th_mtia_init;
  c10::once_flag thp_init;

  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool _deterministic_fill_uninitialized_memory = true;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
  bool enabled_cudnnSDP = false;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  bool enabled_nnpack = true;
  at::LinalgBackend linalg_preferred_backend =
      c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
      ? at::LinalgBackend::Cusolver
      : at::LinalgBackend::Default;
  at::BlasBackend blas_preferred_backend =
      (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
       c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true)
      ? at::BlasBackend::Cublaslt
      : at::BlasBackend::Cublas;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  std::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

  Allocator* prev_allocator_ptr_{nullptr};
};
TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}

TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}
static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasMTIA() {
  return globalContext().hasMTIA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasMAIA() {
  return globalContext().hasMAIA();
}

static inline bool hasXPU() {
  return globalContext().hasXPU();
}
// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION. If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count())
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
        "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
        "means HIP. Rebuild PyTorch with one or the other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}
static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}
static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && cuda_num_gpus > 0) {
    for (const auto i : c10::irange(cuda_num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }
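  // Likewise for XPU: only seed XPU generators when XPU devices are present.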
  const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
  if (hasXPU() && xpu_num_gpus) {
    for (const auto i : c10::irange(xpu_num_gpus)) {
      auto xpu_gen = globalContext().defaultGenerator(
          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
        xpu_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. To help this case, this class implements
// a RAII guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
//   NoTF32Guard disable_tf32;
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};
struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();
};

} // namespace at