zmath.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. #pragma once
  // Complex-number math helpers. For non-complex dtypes most of these fall back
  // to a no-op, a constant, or the ordinary real-valued operation.
  3. #include <c10/util/complex.h>
  4. #include <c10/util/MathConstants.h>
  5. #include<ATen/NumericUtils.h>
  6. namespace at { namespace native {
  7. inline namespace CPU_CAPABILITY {
  // Generic zabs for non-complex dtypes: the input is passed through unchanged
  // (converted to VALUE_TYPE); no absolute value is taken. The c10::complex
  // float/double cases are handled by the explicit specializations below.
  template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
  inline VALUE_TYPE zabs(SCALAR_TYPE z) {
    VALUE_TYPE unchanged = z;
    return unchanged;
  }
  12. template<>
  13. inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
  14. return c10::complex<float>(std::abs(z));
  15. }
  16. template<>
  17. inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
  18. return std::abs(z);
  19. }
  20. template<>
  21. inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
  22. return c10::complex<double>(std::abs(z));
  23. }
  24. template<>
  25. inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
  26. return std::abs(z);
  27. }
  28. // This overload corresponds to non-complex dtypes.
  29. // The function is consistent with its NumPy equivalent
  30. // for non-complex dtypes where `pi` is returned for
  31. // negative real numbers and `0` is returned for 0 or positive
  32. // real numbers.
  33. // Note: `nan` is propagated.
  34. template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
  35. inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
  36. if (at::_isnan(z)) {
  37. return z;
  38. }
  39. return z < 0 ? c10::pi<double> : 0;
  40. }
  41. template<>
  42. inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
  43. return c10::complex<float>(std::arg(z), 0.0);
  44. }
  45. template<>
  46. inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
  47. return std::arg(z);
  48. }
  49. template<>
  50. inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
  51. return c10::complex<double>(std::arg(z), 0.0);
  52. }
  53. template<>
  54. inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
  55. return std::arg(z);
  56. }
  // Real part for non-complex dtypes: the value itself, converted to
  // VALUE_TYPE (no-op when VALUE_TYPE defaults to SCALAR_TYPE).
  template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
  constexpr VALUE_TYPE real_impl(SCALAR_TYPE z) {
    return z;
  }
  61. template<>
  62. constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
  63. return c10::complex<float>(z.real(), 0.0);
  64. }
  65. template<>
  66. constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
  67. return z.real();
  68. }
  69. template<>
  70. constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
  71. return c10::complex<double>(z.real(), 0.0);
  72. }
  73. template<>
  74. constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
  75. return z.real();
  76. }
  // Imaginary part for non-complex dtypes: always zero (converted to
  // VALUE_TYPE). The argument is intentionally unused.
  template <typename SCALAR_TYPE, typename VALUE_TYPE = SCALAR_TYPE>
  constexpr VALUE_TYPE imag_impl(SCALAR_TYPE /*unused*/) {
    return 0;
  }
  81. template<>
  82. constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
  83. return c10::complex<float>(z.imag(), 0.0);
  84. }
  85. template<>
  86. constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
  87. return z.imag();
  88. }
  89. template<>
  90. constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
  91. return c10::complex<double>(z.imag(), 0.0);
  92. }
  93. template<>
  94. constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
  95. return z.imag();
  96. }
  // Complex conjugate for non-complex dtypes: the identity.
  template <typename TYPE>
  inline TYPE conj_impl(TYPE z) {
    return z;  // real values are their own conjugate
  }
  101. template<>
  102. inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
  103. return c10::complex<at::Half>{z.real(), -z.imag()};
  104. }
  105. template<>
  106. inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
  107. return c10::complex<float>(z.real(), -z.imag());
  108. }
  109. template<>
  110. inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
  111. return c10::complex<double>(z.real(), -z.imag());
  112. }
  // Ceiling for non-complex dtypes.
  // Integral inputs are already whole numbers and are returned unchanged.
  // Passing them to std::ceil would promote them to double, which cannot
  // represent every 64-bit integer (precision is lost above 2^53). Converting
  // back is undefined behavior for values near the type's maximum, e.g.
  // INT64_MAX rounds up to 2^63.
  // Floating-point (and other non-integral) inputs use std::ceil.
  template <typename TYPE>
  inline TYPE ceil_impl(TYPE z) {
    if (std::is_integral<TYPE>::value) {
      return z;
    }
    return std::ceil(z);
  }
  117. template <>
  118. inline c10::complex<float> ceil_impl (c10::complex<float> z) {
  119. return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
  120. }
  121. template <>
  122. inline c10::complex<double> ceil_impl (c10::complex<double> z) {
  123. return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
  124. }
  125. template<typename T>
  126. inline c10::complex<T> sgn_impl (c10::complex<T> z) {
  127. if (z == c10::complex<T>(0, 0)) {
  128. return c10::complex<T>(0, 0);
  129. } else {
  130. return z / zabs(z);
  131. }
  132. }
  // Floor for non-complex dtypes.
  // Integral inputs are already whole numbers and are returned unchanged.
  // Passing them to std::floor would promote them to double, which loses
  // precision above 2^53. Converting back is undefined behavior for values
  // near the type's maximum.
  // Floating-point (and other non-integral) inputs use std::floor.
  template <typename TYPE>
  inline TYPE floor_impl(TYPE z) {
    if (std::is_integral<TYPE>::value) {
      return z;
    }
    return std::floor(z);
  }
  137. template <>
  138. inline c10::complex<float> floor_impl (c10::complex<float> z) {
  139. return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
  140. }
  141. template <>
  142. inline c10::complex<double> floor_impl (c10::complex<double> z) {
  143. return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
  144. }
  // Round to an integral value for non-complex dtypes, using std::nearbyint
  // (the current floating-point rounding mode; round-half-to-even by default).
  // Integral inputs are already whole numbers and are returned unchanged.
  // Passing them to std::nearbyint would promote them to double, which loses
  // precision above 2^53. Converting back is undefined behavior for values
  // near the type's maximum.
  template <typename TYPE>
  inline TYPE round_impl(TYPE z) {
    if (std::is_integral<TYPE>::value) {
      return z;
    }
    return std::nearbyint(z);
  }
  149. template <>
  150. inline c10::complex<float> round_impl (c10::complex<float> z) {
  151. return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
  152. }
  153. template <>
  154. inline c10::complex<double> round_impl (c10::complex<double> z) {
  155. return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
  156. }
  // Truncate toward zero for non-complex dtypes.
  // Integral inputs are already whole numbers and are returned unchanged.
  // Passing them to std::trunc would promote them to double, which loses
  // precision above 2^53. Converting back is undefined behavior for values
  // near the type's maximum, e.g. UINT64_MAX rounds up to 2^64.
  // Floating-point (and other non-integral) inputs use std::trunc.
  template <typename TYPE>
  inline TYPE trunc_impl(TYPE z) {
    if (std::is_integral<TYPE>::value) {
      return z;
    }
    return std::trunc(z);
  }
  161. template <>
  162. inline c10::complex<float> trunc_impl (c10::complex<float> z) {
  163. return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
  164. }
  165. template <>
  166. inline c10::complex<double> trunc_impl (c10::complex<double> z) {
  167. return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
  168. }
  169. template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
  170. inline TYPE max_impl (TYPE a, TYPE b) {
  171. if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
  172. return std::numeric_limits<TYPE>::quiet_NaN();
  173. } else {
  174. return std::max(a, b);
  175. }
  176. }
  177. template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
  178. inline TYPE max_impl (TYPE a, TYPE b) {
  179. if (_isnan<TYPE>(a)) {
  180. return a;
  181. } else if (_isnan<TYPE>(b)) {
  182. return b;
  183. } else {
  184. return std::abs(a) > std::abs(b) ? a : b;
  185. }
  186. }
  187. template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
  188. inline TYPE min_impl (TYPE a, TYPE b) {
  189. if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
  190. return std::numeric_limits<TYPE>::quiet_NaN();
  191. } else {
  192. return std::min(a, b);
  193. }
  194. }
  195. template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
  196. inline TYPE min_impl (TYPE a, TYPE b) {
  197. if (_isnan<TYPE>(a)) {
  198. return a;
  199. } else if (_isnan<TYPE>(b)) {
  200. return b;
  201. } else {
  202. return std::abs(a) < std::abs(b) ? a : b;
  203. }
  204. }
  205. } // end namespace
  206. }} //end at::native