DistributionTemplates.h

#pragma once

#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <limits>
#include <mutex>

#ifdef CPU_CAPABILITY_AVX2
#include <ATen/native/cpu/avx_mathfun.h>
#include <c10/util/irange.h>
#endif

namespace at {
namespace native {
namespace templates {
namespace cpu {
namespace {

// ==================================================== Random ========================================================
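// Every distribution below follows the same pattern: a kernel function templated
// on the raw RNG type does the actual sampling while holding the generator's
// mutex (see Note [Acquire lock when using random generators]), and a small
// functor struct unwraps the std::optional<Generator> via check_generator<RNG>()
// and forwards to that kernel.
//
// For random_from_to_kernel, `base` is the inclusive lower bound and `range` is
// the number of representable values, so results fall in [base, base + range).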
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
      uniform_int_from_to_distribution<scalar_t> random(range, base);
      return random(generator);
    });
  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
// This is a special kernel that handles a single specific case:
// from (inclusive) = std::numeric_limits<int64_t>::lowest()
// to (exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
    if constexpr (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      std::lock_guard<std::mutex> lock(generator->mutex_);
      cpu_serial_kernel(iter, [generator]() -> scalar_t {
        uniform_int_full_range_distribution<scalar_t> random;
        return random(generator);
      });
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
    }
  });
}
template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};
template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG generator) {
  std::lock_guard<std::mutex> lock(generator->mutex_);
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
    cpu_serial_kernel(iter, [generator]() -> scalar_t {
      uniform_int_distribution<scalar_t> random;
      return random(generator);
    });
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_kernel(iter, check_generator<RNG>(gen));
  }
};
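// Usage sketch (hypothetical call site; `iter` is an already-built TensorIterator
// and the generator type/wiring here are assumptions, not the in-tree
// registration code):
//
//   at::Generator gen = at::detail::getDefaultCPUGenerator();
//   RandomKernel<at::CPUGeneratorImpl>()(iter, gen);
//
// check_generator<RNG>() verifies that the wrapped generator really is of the
// requested type before the raw kernel dereferences it.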
// ==================================================== Normal ========================================================

#ifdef CPU_CAPABILITY_AVX2
static void normal_fill_16_AVX2(float *data,
                                const __m256* two_pi,
                                const __m256* one,
                                const __m256* minus_two,
                                const __m256* mean,
                                const __m256* std_v) {
  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
  const __m256 u2 = _mm256_loadu_ps(data + 8);
  // sincos256_ps and log256_ps are from avx_mathfun.h
  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
  __m256 sintheta, costheta;
  sincos256_ps(theta, &sintheta, &costheta);
  const __m256 n1 = _mm256_mul_ps(radius, costheta);
  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
}
template<typename RNG>
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<float> uniform(0, 1);
    data[i] = uniform(generator);
  }
  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 minus_two = _mm256_set1_ps(-2.0f);
  const __m256 mean_v = _mm256_set1_ps(mean);
  const __m256 std_v = _mm256_set1_ps(std);
  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<float> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
}
#endif
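// normal_fill_16 (like the AVX2 helper above and the VSX helper below) applies
// the Box-Muller transform in place to a block of 16 uniforms: u1 in (0, 1] and
// u2 in [0, 1) are mapped to two independent standard-normal samples
//   n1 = sqrt(-2 ln u1) * cos(2*pi*u2),  n2 = sqrt(-2 ln u1) * sin(2*pi*u2),
// which are then scaled by `std` and shifted by `mean`.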
template <typename scalar_t>
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
  for (const auto j : c10::irange(8)) {
    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
    const scalar_t u2 = data[j + 8];
    const scalar_t radius = std::sqrt(-2 * std::log(u1));
    const scalar_t theta = 2.0f * c10::pi<double> * u2;
    data[j] = radius * std::cos(theta) * std + mean;
    data[j + 8] = radius * std::sin(theta) * std + mean;
  }
}
#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
static void normal_fill_16_VSX(float *data,
                               const Vectorized<float> &two_pi,
                               const Vectorized<float> &one,
                               const Vectorized<float> &minus_two,
                               const Vectorized<float> &mean,
                               const Vectorized<float> &std) {
  using Vec = Vectorized<float>;
  Vec u1 = one - Vec::loadu(data);
  Vec u2 = Vec::loadu(data + 8);
  Vec radius = (minus_two * u1.log());
  radius = radius.sqrt();
  Vec theta = two_pi * u2;
  Vec output_vec = radius * theta.cos() * std + mean;
  Vec output_vec2 = radius * theta.sin() * std + mean;
  output_vec.store(data);
  output_vec2.store(data + 8);
}

template <typename scalar_t, typename RNG>
void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }
  using Vec = Vectorized<float>;
  const Vec two_pi = Vec(2.0f * c10::pi<double>);
  const Vec one = Vec(1.0f);
  const Vec minus_two = Vec(-2.0f);
  const Vec var_vec = Vec(std);
  const Vec mean_vec = Vec(mean);
  for (int64_t i = 0; i < size - 15; i += 16) {
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec);
    } else {
      normal_fill_16<scalar_t>(data + i, mean, std);
    }
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    if (Vec::size() == 8) {
      normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec);
    } else {
      normal_fill_16<scalar_t>(data, mean, std);
    }
  }
}
#endif //VSX
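// Generic fallback: fill the buffer with uniforms in [0, 1), then transform it
// in place in blocks of 16 with normal_fill_16. If `size` is not a multiple of
// 16, the last 16 values are regenerated and transformed again so no element is
// left untransformed; this is why callers require size >= 16 for this path.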
template <typename scalar_t, typename RNG>
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  scalar_t *data = self.data_ptr<scalar_t>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }
  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16<scalar_t>(data + i, mean, std);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16<scalar_t>(data, mean, std);
  }
}
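// Dispatch policy: contiguous float tensors with at least 16 elements take the
// AVX2/VSX fill path when available; other floating dtypes that are contiguous
// with at least 16 elements use the scalar normal_fill; everything else falls
// back to a TensorIterator loop that samples a double-precision
// normal_distribution per element and casts to the output type.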
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
  auto size = self.numel();
  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
#ifdef CPU_CAPABILITY_AVX2
    normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
    normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
#else
    normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
      if (size >= 16 && self.is_contiguous()) {
        normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
      } else {
        auto iter = TensorIterator::borrowing_nullary_op(self);
        std::lock_guard<std::mutex> lock(generator->mutex_);
        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
          at::normal_distribution<double> normal(mean, std);
          return static_cast<scalar_t>(normal(generator));
        });
      }
    });
  }
}

template<typename RNG>
struct NormalKernel {
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};
// ==================================================== Uniform =======================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    at::uniform_real_distribution<scalar_t> uniform(from, to);
    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
      return static_cast<scalar_t>(uniform(generator));
    });
  });
}
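// Note: unlike the kernels below, which sample in double precision and cast the
// result, the uniform bounds are cast to the output scalar type up front and the
// distribution is parameterized directly in that type.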
template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};
// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::cauchy_distribution<double> cauchy(median, sigma);
    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
      return static_cast<scalar_t>(cauchy(generator));
    });
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};
// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::lognormal_distribution<double> logNormal(mean, std);
    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
      return static_cast<scalar_t>(logNormal(generator));
    });
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};
// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::geometric_distribution<double> geometric(p);
    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
      return static_cast<scalar_t>(geometric(generator));
    });
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};
// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::exponential_distribution<double> exponential(lambda);
    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
      return static_cast<scalar_t>(exponential(generator));
    });
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};
// ================================================== Bernoulli =======================================================
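// The tensor-probability overload below copies `p_` to CPU, broadcasts it
// against `self` with expand_inplace, and nests a second dispatch so that any
// floating probability dtype can drive any output dtype.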
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    auto p_cpu = p_.to(kCPU);
    auto p = expand_inplace(self, p_cpu);
    auto iter = TensorIteratorConfig()
        .add_output(self)
        .add_const_input(*p)
        .check_all_same_dtype(false)
        .build();
    if (p->scalar_type() == kDouble) {
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
          p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto iter = TensorIterator::borrowing_nullary_op(self);
    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
      at::bernoulli_distribution<double> bernoulli(p);
      return static_cast<scalar_t>(bernoulli(generator));
    });
  });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
    bernoulli_kernel(self, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};
} // anonymous namespace
} // namespace cpu
} // namespace templates
} // namespace native
} // namespace at