DistributionTemplates.h

#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/OpMathType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/core/DistributionsHelper.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cstdint>
#include <limits>
#include <utility>
#include <mutex>
#include <tuple>
#include <type_traits>

namespace at {
namespace native {
namespace {

// launch bounds used for kernels utilizing TensorIterator
const uint32_t block_size_bound = 256;
const uint32_t grid_size_bound = 4;
// number of randoms given by distributions like curand_uniform4, curand_uniform2_double
// used in calculating philox offset.
const uint32_t curand4_engine_calls = 4;

// utility function that calculates proper philox_offset
// for distributions utilizing TensorIterator. For distributions using
// TensorIterator, we use a grid-stride loop in which each thread
// yields one element per iteration. At the edge of the grid-stride
// loop, if the tensor size is large, the unroll loop kicks in and the float4
// from curand4 starts getting fully utilized (for common tensor sizes, we end up
// using only rand.x from each thread). Hence, the philox_offset is
// (number of elements per thread * number of engine calls), which makes
// sure that the philox offset increment is not less than the number of randoms
// used in each thread.
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
  const uint64_t numel = static_cast<uint64_t>(total_elements);
  const uint32_t block_size = block_size_bound;
  const uint32_t unroll = curand4_engine_calls;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);
  uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
      grid.x);
  // number of times random will be generated per thread, to offset philox counter in thc random state
  uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
                            * curand4_engine_calls;
  return std::make_tuple(counter_offset, grid, dim_block);
}
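
// Illustrative sketch, not part of the original file: a constexpr mirror of the
// counter_offset formula above, evaluated under assumed device figures (108 SMs,
// maxThreadsPerMultiProcessor = 2048, so blocks_per_sm = 2048 / 256 = 8 and grid.x
// is capped at 108 * 8 = 864 blocks). The numbers are hypothetical and exist only
// to make the arithmetic concrete.
constexpr uint64_t example_counter_offset(uint64_t numel, uint64_t grid_x) {
  return ((numel - 1) / (block_size_bound * grid_x * curand4_engine_calls) + 1)
      * curand4_engine_calls;
}
// For numel = 1 << 20: ((1048576 - 1) / (256 * 864 * 4) + 1) * 4 = 2 * 4 = 8, i.e.
// each thread runs two grid-stride iterations and consumes at most 8 Philox outputs.
static_assert(example_counter_offset(1ULL << 20, 864) == 8,
              "illustrative check of the philox offset arithmetic");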

// grid stride loop kernel for distributions
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);
  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
      blockDim.x * gridDim.x * unroll_factor;
  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    __syncthreads();
  }
}

/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For large tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
 *     offset calculation is done in this function.
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have a grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because we found by testing
 * that it achieves peak effective bandwidth.
 */
template<typename scalar_t,
         typename accscalar_t,
         int unroll_factor,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_trivial_1d()) {
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
        numel,
        rng_engine_inputs,
        dist_func,
        [=]__device__(int idx, accscalar_t rand) {
          scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
          *out = transform_func(rand);
        }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
        numel,
        rng_engine_inputs,
        dist_func,
        [=]__device__(int idx, accscalar_t rand) {
          auto offsets = offset_calc.get(idx);
          scalar_t* out = (scalar_t*)&out_data[offsets[0]];
          *out = transform_func(rand);
        }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
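
// Hedged usage sketch, not part of the original file: a minimal, hypothetical caller
// of distribution_nullary_kernel that fills a float iterator with raw curand_uniform4
// draws through an identity transform. The real callers are the *_and_transform
// helpers and the distribution kernels defined below; `example_uniform_fill_sketch`
// exists only to show how the dist/transform functor pair plugs in.
template <typename RNG>
void example_uniform_fill_sketch(at::TensorIteratorBase& iter, RNG gen) {
  distribution_nullary_kernel<float, float, curand4_engine_calls>(
      iter,
      gen,
      // dist functor: one engine call yields four floats in (0, 1]
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
      // transform functor: identity, one output element per random value
      [] __device__ (float rand) { return rand; });
}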

// Binary kernel
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,
    func_t f,
    PhiloxCudaState philox_args,
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,
    out_offset_calc_t out_calc) {
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  int remaining = std::min<int>(numel - base_index, block_work_size());

  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];
    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}

template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t&");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
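
// Hedged usage sketch, not part of the original file: the functor handed to
// distribution_binary_kernel must take the Philox state by reference as its first
// argument (enforced by the static_assert above), followed by the two elementwise
// inputs. This hypothetical helper draws one normal(mean, std) sample per output
// element from a pair of mean/std tensors wired up through `iter`.
template <typename scalar_t>
void example_binary_normal_sketch(TensorIteratorBase& iter, PhiloxCudaState philox_args) {
  distribution_binary_kernel(iter, philox_args,
      [] __device__ (curandStatePhilox4_32_10_t& state, scalar_t mean, scalar_t std) -> scalar_t {
        // the per-thread Philox state is initialized inside the kernel above
        return static_cast<scalar_t>(curand_normal(&state)) * std + mean;
      });
}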

} // namespace
}} // namespace at::native

namespace at {
namespace native {
namespace templates {
namespace cuda {

// ==================================================== Random ========================================================

template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
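
// Hedged illustration, not part of the original file: the 64-bit branch above packs
// two 32-bit curand4 outputs into one 64-bit value. `example_pack_two_u32` is a
// standalone constexpr mirror of that packing with a spot check of the bit layout.
constexpr uint64_t example_pack_two_u32(uint32_t hi, uint32_t lo) {
  return (static_cast<uint64_t>(hi) << 32) | lo;
}
static_assert(example_pack_two_u32(0x1u, 0x2u) == 0x0000000100000002ULL,
              "hi lands in the upper 32 bits, lo in the lower 32 bits");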

// This is the special kernel to handle a single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        },
        random_func);
    }
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};

// ====================================================================================================================

template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
      transform);
  }
}

template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
      transform);
  }
}

// ==================================================== Normal ========================================================

template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, normal_func);
  });
}

template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform ========================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to-from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // reverse the bounds of curand4 from (0, 1] to [0, 1)
      // Note that this method is from legacy THCTensorRandom and is likely to give
      // you more 0-s, since the probability of getting 1-s is higher than 0-s and
      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t, curand4_engine_calls>(iter, gen, uniform_func);
  });
}
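
// Hedged illustration, not part of the original file: `example_reverse_bound` mirrors
// the reverse_bound_value logic above for the scalar_t == opmath_t == float case.
// curand_uniform* returns values in (0, 1], so `rand * range + from` can land exactly
// on `to`; that single endpoint is remapped to `from`, giving the documented [from, to).
constexpr float example_reverse_bound(float rand, float from, float to) {
  const float value = rand * (to - from) + from;
  return value == to ? from : value;
}
static_assert(example_reverse_bound(1.0f, 0.0f, 1.0f) == 0.0f,
              "the (0, 1] endpoint from curand maps back to `from`");
static_assert(example_reverse_bound(0.25f, 0.0f, 1.0f) == 0.25f,
              "interior values pass through unchanged");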

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, log_normal_func);
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, geometric_func);
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point type but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, exponential_func);
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
    };
    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, cauchy_func);
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ==================================================== Bernoulli =====================================================

template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
    const TensorBase &ret, const at::TensorBase &p,
    PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        // See Note [Register spilling in curand call for CUDA < 10]
        float4 rand = curand_uniform4(&state);
        switch (n) {
          case 4: {
            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
            // fallthrough
          }
          case 3: {
            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
            // fallthrough
          }
          case 2: {
            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
            // fallthrough
          }
          case 1: {
            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      };
  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, const prob_t, 4, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, p, functor);
}

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same<scalar_t, double>::value) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
   });
}

template<typename RNG>
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      // define lambda for bernoulli transformation
      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, bernoulli_func);
   });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    bernoulli_kernel(iter, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};

}}}}