// UnaryOps.h
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <ATen/Generator.h>
  4. #include <c10/core/Scalar.h>
  5. #include <stdexcept>
  // Forward declarations only. The signatures in this header take these types
  // by reference, so their full definitions are not required here.
  namespace at {
  struct TensorIteratorBase;
  class TensorBase;
  class Tensor;
  } // namespace at
  11. namespace at::native {
  12. using unary_fn = void(*)(TensorIteratorBase&);
  13. using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
  14. inline namespace CPU_CAPABILITY {
  15. void conj_kernel(TensorIteratorBase &iter);
  16. void neg_kernel(TensorIteratorBase &iter);
  17. void reciprocal_kernel(TensorIteratorBase &iter);
  18. void rsqrt_kernel(TensorIteratorBase& iter);
  19. void sqrt_kernel(TensorIteratorBase& iter);
  20. } // namespace CPU_CAPABILITY
  21. DECLARE_DISPATCH(unary_fn, abs_stub);
  22. DECLARE_DISPATCH(unary_fn, angle_stub);
  23. DECLARE_DISPATCH(unary_fn, conj_physical_stub);
  24. DECLARE_DISPATCH(unary_fn, acos_stub);
  25. DECLARE_DISPATCH(unary_fn, acosh_stub);
  26. DECLARE_DISPATCH(unary_fn, asinh_stub);
  27. DECLARE_DISPATCH(unary_fn, atanh_stub);
  28. DECLARE_DISPATCH(unary_fn, asin_stub);
  29. DECLARE_DISPATCH(unary_fn, atan_stub);
  30. DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
  31. DECLARE_DISPATCH(unary_fn, logical_not_stub);
  32. DECLARE_DISPATCH(unary_fn, ceil_stub);
  33. DECLARE_DISPATCH(unary_fn, cos_stub);
  34. DECLARE_DISPATCH(unary_fn, cosh_stub);
  35. DECLARE_DISPATCH(unary_fn, digamma_stub);
  36. DECLARE_DISPATCH(unary_fn, special_entr_stub);
  37. DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
  38. DECLARE_DISPATCH(unary_fn, erf_stub);
  39. DECLARE_DISPATCH(unary_fn, erfc_stub);
  40. DECLARE_DISPATCH(unary_fn, erfinv_stub);
  41. DECLARE_DISPATCH(unary_fn, exp_stub);
  42. DECLARE_DISPATCH(unary_fn, exp2_stub);
  43. DECLARE_DISPATCH(unary_fn, expm1_stub);
  44. DECLARE_DISPATCH(unary_fn, floor_stub);
  45. DECLARE_DISPATCH(unary_fn, frac_stub);
  46. DECLARE_DISPATCH(unary_fn, frexp_stub);
  47. DECLARE_DISPATCH(unary_fn, i0_stub);
  48. DECLARE_DISPATCH(unary_fn, special_i0e_stub);
  49. DECLARE_DISPATCH(unary_fn, special_i1_stub);
  50. DECLARE_DISPATCH(unary_fn, special_i1e_stub);
  51. DECLARE_DISPATCH(unary_fn, log_stub);
  52. DECLARE_DISPATCH(unary_fn, log10_stub);
  53. DECLARE_DISPATCH(unary_fn, log1p_stub);
  54. DECLARE_DISPATCH(unary_fn, log2_stub);
  55. DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
  56. DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
  57. DECLARE_DISPATCH(unary_fn, neg_stub);
  58. DECLARE_DISPATCH(unary_fn, reciprocal_stub);
  59. DECLARE_DISPATCH(unary_fn, round_stub);
  60. DECLARE_DISPATCH(unary_fn, rsqrt_stub);
  61. DECLARE_DISPATCH(unary_fn, sigmoid_stub);
  62. DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
  63. DECLARE_DISPATCH(unary_fn, sign_stub);
  64. DECLARE_DISPATCH(unary_fn, signbit_stub);
  65. DECLARE_DISPATCH(unary_fn, sgn_stub);
  66. DECLARE_DISPATCH(unary_fn, sin_stub);
  67. DECLARE_DISPATCH(unary_fn, sinc_stub);
  68. DECLARE_DISPATCH(unary_fn, sinh_stub);
  69. DECLARE_DISPATCH(unary_fn, sqrt_stub);
  70. DECLARE_DISPATCH(unary_fn, tan_stub);
  71. DECLARE_DISPATCH(unary_fn, tanh_stub);
  72. DECLARE_DISPATCH(unary_fn, trigamma_stub);
  73. DECLARE_DISPATCH(unary_fn, trunc_stub);
  74. DECLARE_DISPATCH(unary_fn, lgamma_stub);
  75. DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
  76. DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
  77. DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
  78. DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
  79. DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
  80. DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
  81. DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
  82. DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
  83. DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
  84. DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
  85. DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
  86. DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
  87. // NB: these are actually defined in Distribution
  88. DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub);
  89. DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub);
  90. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub);
  91. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub);
  92. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub);
  93. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub);
  94. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub);
  95. DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub);
  96. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub);
  97. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub);
  98. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub);
  99. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
  100. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
  101. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
  102. DECLARE_DISPATCH(
  103. void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
  104. multinomial_with_replacement_stub);
  105. DECLARE_DISPATCH(
  106. void (*)(
  107. TensorIteratorBase&,
  108. std::optional<double>,
  109. std::optional<double>,
  110. std::optional<double>),
  111. nan_to_num_stub);
  112. DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
  113. // Missing unary functions
  114. // digamma
  115. // lgamma
  116. // erfinv
  117. // clone
  118. // contiguous
  119. // zero
  120. } // namespace at::native