complex.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. #pragma once
  2. #include <complex>
  3. #include <c10/macros/Macros.h>
  4. #if defined(__CUDACC__) || defined(__HIPCC__)
  5. #include <thrust/complex.h>
  6. #endif
  7. C10_CLANG_DIAGNOSTIC_PUSH()
  8. #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
  9. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
  10. #endif
  11. #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
  12. C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
  13. #endif
  14. namespace c10 {
  15. // c10::complex is an implementation of complex numbers that aims
  16. // to work on all devices supported by PyTorch
  17. //
  // Most of the APIs duplicate std::complex
  19. // Reference: https://en.cppreference.com/w/cpp/numeric/complex
  20. //
  21. // [NOTE: Complex Operator Unification]
  22. // Operators currently use a mix of std::complex, thrust::complex, and
  23. // c10::complex internally. The end state is that all operators will use
  24. // c10::complex internally. Until then, there may be some hacks to support all
  25. // variants.
  26. //
  27. //
  28. // [Note on Constructors]
  29. //
  30. // The APIs of constructors are mostly copied from C++ standard:
  31. // https://en.cppreference.com/w/cpp/numeric/complex/complex
  32. //
  33. // Since C++14, all constructors are constexpr in std::complex
  34. //
  35. // There are three types of constructors:
  36. // - initializing from real and imag:
  37. // `constexpr complex( const T& re = T(), const T& im = T() );`
  38. // - implicitly-declared copy constructor
  39. // - converting constructors
  40. //
  41. // Converting constructors:
  42. // - std::complex defines converting constructor between float/double/long
  43. // double,
  44. // while we define converting constructor between float/double.
  45. // - For these converting constructors, upcasting is implicit, downcasting is
  46. // explicit.
  47. // - We also define explicit casting from std::complex/thrust::complex
  // - Note that the conversion from thrust is not constexpr, because
  // thrust does not appear to define them as constexpr
  50. //
  51. //
  52. // [Operator =]
  53. //
  54. // The APIs of operator = are mostly copied from C++ standard:
  55. // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
  56. //
  57. // Since C++20, all operator= are constexpr. Although we are not building with
  58. // C++20, we also obey this behavior.
  59. //
  60. // There are three types of assign operator:
  61. // - Assign a real value from the same scalar type
  // - In std, this is templated as complex& operator=(const T& x)
  // with specialization `complex& operator=(T x)` for float/double/long
  // double. Since we only support float and double, we will use
  // `complex& operator=(T x)`
  66. // - Copy assignment operator and converting assignment operator
  // - There is no specialization of converting assignment operators; which
  // types are convertible depends solely on whether the scalar type is
  // convertible
  70. //
  71. // In addition to the standard assignment, we also provide assignment operators
  72. // with std and thrust
  73. //
  74. //
  75. // [Casting operators]
  76. //
  77. // std::complex does not have casting operators. We define casting operators
  78. // casting to std::complex and thrust::complex
  79. //
  80. //
  81. // [Operator ""]
  82. //
  83. // std::complex has custom literals `i`, `if` and `il` defined in namespace
  84. // `std::literals::complex_literals`. We define our own custom literals in the
  // namespace `c10::complex_literals`. Our custom literals do not follow the
  // same behavior as in std::complex; instead, we define _if, _id to construct
  // float/double complex literals.
  88. //
  89. //
  90. // [real() and imag()]
  91. //
  // In C++20, there are two overloads of these functions: one returns the
  // real/imag part, the other sets it. Both are constexpr. We follow
  // this design.
  95. //
  96. //
  97. // [Operator +=,-=,*=,/=]
  98. //
  99. // Since C++20, these operators become constexpr. In our implementation, they
  100. // are also constexpr.
  101. //
  102. // There are two types of such operators: operating with a real number, or
  103. // operating with another complex number. For the operating with a real number,
  104. // the generic template form has argument type `const T &`, while the overload
  105. // for float/double/long double has `T`. We will follow the same type as
  106. // float/double/long double in std.
  107. //
  108. // [Unary operator +-]
  109. //
  // Since C++20, they are constexpr. We also make them constexpr
  111. //
  112. // [Binary operators +-*/]
  113. //
  114. // Each operator has three versions (taking + as example):
  115. // - complex + complex
  116. // - complex + real
  117. // - real + complex
  118. //
  119. // [Operator ==, !=]
  120. //
  121. // Each operator has three versions (taking == as example):
  122. // - complex == complex
  123. // - complex == real
  124. // - real == complex
  125. //
  // Some of them are removed in C++20, but we decided to keep them
  127. //
  128. // [Operator <<, >>]
  129. //
  130. // These are implemented by casting to std::complex
  131. //
  132. //
  133. //
  134. // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
  135. // because:
  136. // - lots of members and functions of c10::Half are not constexpr
  137. // - thrust::complex only support float and double
  138. template <typename T>
  139. struct alignas(sizeof(T) * 2) complex {
  140. using value_type = T;
  141. T real_ = T(0);
  142. T imag_ = T(0);
  143. constexpr complex() = default;
  144. C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
  145. : real_(re), imag_(im) {}
  146. template <typename U>
  147. explicit constexpr complex(const std::complex<U>& other)
  148. : complex(other.real(), other.imag()) {}
  149. #if defined(__CUDACC__) || defined(__HIPCC__)
  150. template <typename U>
  151. explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
  152. : real_(other.real()), imag_(other.imag()) {}
  153. // NOTE can not be implemented as follow due to ROCm bug:
  154. // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
  155. // complex(other.real(), other.imag()) {}
  156. #endif
  157. // Use SFINAE to specialize casting constructor for c10::complex<float> and
  158. // c10::complex<double>
  159. template <typename U = T>
  160. C10_HOST_DEVICE explicit constexpr complex(
  161. const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
  162. : real_(other.real_), imag_(other.imag_) {}
  163. template <typename U = T>
  164. C10_HOST_DEVICE constexpr complex(
  165. const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
  166. : real_(other.real_), imag_(other.imag_) {}
  167. constexpr complex<T>& operator=(T re) {
  168. real_ = re;
  169. imag_ = 0;
  170. return *this;
  171. }
  172. constexpr complex<T>& operator+=(T re) {
  173. real_ += re;
  174. return *this;
  175. }
  176. constexpr complex<T>& operator-=(T re) {
  177. real_ -= re;
  178. return *this;
  179. }
  180. constexpr complex<T>& operator*=(T re) {
  181. real_ *= re;
  182. imag_ *= re;
  183. return *this;
  184. }
  185. constexpr complex<T>& operator/=(T re) {
  186. real_ /= re;
  187. imag_ /= re;
  188. return *this;
  189. }
  190. template <typename U>
  191. constexpr complex<T>& operator=(const complex<U>& rhs) {
  192. real_ = rhs.real();
  193. imag_ = rhs.imag();
  194. return *this;
  195. }
  196. template <typename U>
  197. constexpr complex<T>& operator+=(const complex<U>& rhs) {
  198. real_ += rhs.real();
  199. imag_ += rhs.imag();
  200. return *this;
  201. }
  202. template <typename U>
  203. constexpr complex<T>& operator-=(const complex<U>& rhs) {
  204. real_ -= rhs.real();
  205. imag_ -= rhs.imag();
  206. return *this;
  207. }
  208. template <typename U>
  209. constexpr complex<T>& operator*=(const complex<U>& rhs) {
  210. // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
  211. T a = real_;
  212. T b = imag_;
  213. U c = rhs.real();
  214. U d = rhs.imag();
  215. real_ = a * c - b * d;
  216. imag_ = a * d + b * c;
  217. return *this;
  218. }
  219. #ifdef __APPLE__
  220. #define FORCE_INLINE_APPLE __attribute__((always_inline))
  221. #else
  222. #define FORCE_INLINE_APPLE
  223. #endif
  224. template <typename U>
  225. constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
  226. __ubsan_ignore_float_divide_by_zero__ {
  227. // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
  228. // the calculation below follows numpy's complex division
  229. T a = real_;
  230. T b = imag_;
  231. U c = rhs.real();
  232. U d = rhs.imag();
  233. #if defined(__GNUC__) && !defined(__clang__)
  234. // std::abs is already constexpr by gcc
  235. auto abs_c = std::abs(c);
  236. auto abs_d = std::abs(d);
  237. #else
  238. auto abs_c = c < 0 ? -c : c;
  239. auto abs_d = d < 0 ? -d : d;
  240. #endif
  241. if (abs_c >= abs_d) {
  242. if (abs_c == 0 && abs_d == 0) {
  243. /* divide by zeros should yield a complex inf or nan */
  244. real_ = a / abs_c;
  245. imag_ = b / abs_d;
  246. } else {
  247. auto rat = d / c;
  248. auto scl = 1.0 / (c + d * rat);
  249. real_ = (a + b * rat) * scl;
  250. imag_ = (b - a * rat) * scl;
  251. }
  252. } else {
  253. auto rat = c / d;
  254. auto scl = 1.0 / (d + c * rat);
  255. real_ = (a * rat + b) * scl;
  256. imag_ = (b * rat - a) * scl;
  257. }
  258. return *this;
  259. }
  260. #undef FORCE_INLINE_APPLE
  261. template <typename U>
  262. constexpr complex<T>& operator=(const std::complex<U>& rhs) {
  263. real_ = rhs.real();
  264. imag_ = rhs.imag();
  265. return *this;
  266. }
  267. #if defined(__CUDACC__) || defined(__HIPCC__)
  268. template <typename U>
  269. C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
  270. real_ = rhs.real();
  271. imag_ = rhs.imag();
  272. return *this;
  273. }
  274. #endif
  275. template <typename U>
  276. explicit constexpr operator std::complex<U>() const {
  277. return std::complex<U>(std::complex<T>(real(), imag()));
  278. }
  279. #if defined(__CUDACC__) || defined(__HIPCC__)
  280. template <typename U>
  281. C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
  282. return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
  283. }
  284. #endif
  285. // consistent with NumPy behavior
  286. explicit constexpr operator bool() const {
  287. return real() || imag();
  288. }
  289. C10_HOST_DEVICE constexpr T real() const {
  290. return real_;
  291. }
  292. constexpr void real(T value) {
  293. real_ = value;
  294. }
  295. C10_HOST_DEVICE constexpr T imag() const {
  296. return imag_;
  297. }
  298. constexpr void imag(T value) {
  299. imag_ = value;
  300. }
  301. };
  302. namespace complex_literals {
  303. constexpr complex<float> operator""_if(long double imag) {
  304. return complex<float>(0.0f, static_cast<float>(imag));
  305. }
  306. constexpr complex<double> operator""_id(long double imag) {
  307. return complex<double>(0.0, static_cast<double>(imag));
  308. }
  309. constexpr complex<float> operator""_if(unsigned long long imag) {
  310. return complex<float>(0.0f, static_cast<float>(imag));
  311. }
  312. constexpr complex<double> operator""_id(unsigned long long imag) {
  313. return complex<double>(0.0, static_cast<double>(imag));
  314. }
  315. } // namespace complex_literals
  316. template <typename T>
  317. constexpr complex<T> operator+(const complex<T>& val) {
  318. return val;
  319. }
  320. template <typename T>
  321. constexpr complex<T> operator-(const complex<T>& val) {
  322. return complex<T>(-val.real(), -val.imag());
  323. }
  324. template <typename T>
  325. constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
  326. complex<T> result = lhs;
  327. return result += rhs;
  328. }
  329. template <typename T>
  330. constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
  331. complex<T> result = lhs;
  332. return result += rhs;
  333. }
  334. template <typename T>
  335. constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
  336. return complex<T>(lhs + rhs.real(), rhs.imag());
  337. }
  338. template <typename T>
  339. constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
  340. complex<T> result = lhs;
  341. return result -= rhs;
  342. }
  343. template <typename T>
  344. constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
  345. complex<T> result = lhs;
  346. return result -= rhs;
  347. }
  348. template <typename T>
  349. constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
  350. complex<T> result = -rhs;
  351. return result += lhs;
  352. }
  353. template <typename T>
  354. constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
  355. complex<T> result = lhs;
  356. return result *= rhs;
  357. }
  358. template <typename T>
  359. constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
  360. complex<T> result = lhs;
  361. return result *= rhs;
  362. }
  363. template <typename T>
  364. constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
  365. complex<T> result = rhs;
  366. return result *= lhs;
  367. }
  368. template <typename T>
  369. constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
  370. complex<T> result = lhs;
  371. return result /= rhs;
  372. }
  373. template <typename T>
  374. constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
  375. complex<T> result = lhs;
  376. return result /= rhs;
  377. }
  378. template <typename T>
  379. constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
  380. complex<T> result(lhs, T());
  381. return result /= rhs;
  382. }
  383. // Define operators between integral scalars and c10::complex. std::complex does
  384. // not support this when T is a floating-point number. This is useful because it
  385. // saves a lot of "static_cast" when operate a complex and an integer. This
  386. // makes the code both less verbose and potentially more efficient.
  387. #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
  388. typename std::enable_if_t< \
  389. std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
  390. int> = 0
  391. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  392. constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
  393. return a + static_cast<fT>(b);
  394. }
  395. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  396. constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
  397. return static_cast<fT>(a) + b;
  398. }
  399. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  400. constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
  401. return a - static_cast<fT>(b);
  402. }
  403. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  404. constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
  405. return static_cast<fT>(a) - b;
  406. }
  407. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  408. constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
  409. return a * static_cast<fT>(b);
  410. }
  411. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  412. constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
  413. return static_cast<fT>(a) * b;
  414. }
  415. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  416. constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
  417. return a / static_cast<fT>(b);
  418. }
  419. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  420. constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
  421. return static_cast<fT>(a) / b;
  422. }
  423. #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
  424. template <typename T>
  425. constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
  426. return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
  427. }
  428. template <typename T>
  429. constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
  430. return (lhs.real() == rhs) && (lhs.imag() == T());
  431. }
  432. template <typename T>
  433. constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
  434. return (lhs == rhs.real()) && (T() == rhs.imag());
  435. }
  436. template <typename T>
  437. constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
  438. return !(lhs == rhs);
  439. }
  440. template <typename T>
  441. constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
  442. return !(lhs == rhs);
  443. }
  444. template <typename T>
  445. constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
  446. return !(lhs == rhs);
  447. }
  448. template <typename T, typename CharT, typename Traits>
  449. std::basic_ostream<CharT, Traits>& operator<<(
  450. std::basic_ostream<CharT, Traits>& os,
  451. const complex<T>& x) {
  452. return (os << static_cast<std::complex<T>>(x));
  453. }
  454. template <typename T, typename CharT, typename Traits>
  455. std::basic_istream<CharT, Traits>& operator>>(
  456. std::basic_istream<CharT, Traits>& is,
  457. complex<T>& x) {
  458. std::complex<T> tmp;
  459. is >> tmp;
  460. x = tmp;
  461. return is;
  462. }
  463. } // namespace c10
  464. // std functions
  465. //
  466. // The implementation of these functions also follow the design of C++20
  467. namespace std {
  468. template <typename T>
  469. constexpr T real(const c10::complex<T>& z) {
  470. return z.real();
  471. }
  472. template <typename T>
  473. constexpr T imag(const c10::complex<T>& z) {
  474. return z.imag();
  475. }
  476. template <typename T>
  477. C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
  478. #if defined(__CUDACC__) || defined(__HIPCC__)
  479. return thrust::abs(static_cast<thrust::complex<T>>(z));
  480. #else
  481. return std::abs(static_cast<std::complex<T>>(z));
  482. #endif
  483. }
  484. #if defined(USE_ROCM)
  485. #define ROCm_Bug(x)
  486. #else
  487. #define ROCm_Bug(x) x
  488. #endif
  489. template <typename T>
  490. C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
  491. return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
  492. }
  493. #undef ROCm_Bug
  494. template <typename T>
  495. constexpr T norm(const c10::complex<T>& z) {
  496. return z.real() * z.real() + z.imag() * z.imag();
  497. }
  498. // For std::conj, there are other versions of it:
  499. // constexpr std::complex<float> conj( float z );
  500. // template< class DoubleOrInteger >
  501. // constexpr std::complex<double> conj( DoubleOrInteger z );
  502. // constexpr std::complex<long double> conj( long double z );
  503. // These are not implemented
  504. // TODO(@zasdfgbnm): implement them as c10::conj
  505. template <typename T>
  506. constexpr c10::complex<T> conj(const c10::complex<T>& z) {
  507. return c10::complex<T>(z.real(), -z.imag());
  508. }
  509. // Thrust does not have complex --> complex version of thrust::proj,
  510. // so this function is not implemented at c10 right now.
  511. // TODO(@zasdfgbnm): implement it by ourselves
  512. // There is no c10 version of std::polar, because std::polar always
  513. // returns std::complex. Use c10::polar instead;
  514. } // namespace std
  515. namespace c10 {
  516. template <typename T>
  517. C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
  518. #if defined(__CUDACC__) || defined(__HIPCC__)
  519. return static_cast<complex<T>>(thrust::polar(r, theta));
  520. #else
  521. // std::polar() requires r >= 0, so spell out the explicit implementation to
  522. // avoid a branch.
  523. return complex<T>(r * std::cos(theta), r * std::sin(theta));
  524. #endif
  525. }
  526. } // namespace c10
  527. C10_CLANG_DIAGNOSTIC_POP()
  528. #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
  529. // math functions are included in a separate file
  530. #include <c10/util/complex_math.h> // IWYU pragma: keep
  531. // utilities for complex types
  532. #include <c10/util/complex_utils.h> // IWYU pragma: keep
  533. #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H