generic_math.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <c10/util/TypeSafeSignMath.h>
  4. #include <cmath>
  5. #if defined(__CUDA_ARCH__)
  6. #include <c10/cuda/CUDAMathCompat.h>
  7. #define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
  8. #elif defined(__HIPCC__)
  9. #include <c10/hip/HIPMathCompat.h>
  10. #define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
  11. #else
  12. #include <c10/util/copysign.h>
  13. #define C10_COMPAT_COPYSIGN c10::copysign
  14. #endif
  15. // The functions in this file should be header-only as it is used under
  16. // ABI-compatibility mode.
  17. namespace c10 {
  18. // NOTE: [Floor Division in Python]
  19. // Python's __floordiv__ operator is more complicated than just floor(a / b).
  20. // It aims to maintain the property: a == (a // b) * b + remainder(a, b)
  21. // which can otherwise fail due to rounding errors in the remainder.
  22. // So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
  23. // With some additional fix-ups added to the result.
  24. //
  25. // For reference, see CPython's implementation:
  26. // https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
  27. template <typename scalar_t>
  28. inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b)
  29. __ubsan_ignore_float_divide_by_zero__ {
  30. if (C10_UNLIKELY(b == 0)) {
  31. // Divide by zero: return standard IEEE result
  32. return a / b;
  33. }
  34. auto mod = std::fmod(a, b);
  35. auto div = (a - mod) / b;
  36. if ((mod != 0) && (b < 0) != (mod < 0)) {
  37. div -= scalar_t(1);
  38. }
  39. scalar_t floordiv;
  40. if (div != 0) {
  41. floordiv = std::floor(div);
  42. if (div - floordiv > scalar_t(0.5)) {
  43. floordiv += scalar_t(1.0);
  44. }
  45. } else {
  46. floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b);
  47. }
  48. return floordiv;
  49. }
  50. template <typename scalar_t>
  51. inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
  52. if (c10::signs_differ(a, b)) {
  53. // Subtracts one from the results of truncation division if the
  54. // divisor and dividend have different sign(bit)s and the remainder of
  55. // the division is nonzero
  56. const auto quot = a / b;
  57. const auto rem = a % b;
  58. return rem ? quot - 1 : quot;
  59. }
  60. return a / b;
  61. }
  62. } // namespace c10