ParamsHash.h 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #pragma once
  #include <c10/util/irange.h>
  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <memory>
  #include <mutex>
  #include <type_traits>
  5. namespace at::native {
  // Hashing machinery for Params.
  // Uses the 32-bit Fowler–Noll–Vo (FNV-1a) hash over the raw object bytes; see
  // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
  template <typename Params>
  struct ParamsHash {
    // The hash walks the object representation of `params` byte by byte, so
    // Params is required to be a standard-layout ("POD") type. Padding bytes
    // are hashed as well.
    static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

    size_t operator()(const Params& params) const {
      constexpr uint32_t kFnvOffsetBasis = 0x811C9DC5;
      constexpr uint32_t kFnvPrime = 0x01000193;
      const uint8_t* const first = reinterpret_cast<const uint8_t*>(&params);
      const uint8_t* const last = first + sizeof(Params);
      uint32_t hash = kFnvOffsetBasis;
      for (const uint8_t* byte = first; byte != last; ++byte) {
        // FNV-1a: xor the byte in first, then multiply by the prime
        // (arithmetic wraps modulo 2^32).
        hash = (hash ^ *byte) * kFnvPrime;
      }
      // The 32-bit result is zero-extended to size_t.
      return static_cast<size_t>(hash);
    }
  };
  template <typename Params>
  struct ParamsEqual {
    // Two Params compare equal iff their object representations are
    // byte-for-byte identical, so Params must be a standard-layout ("POD")
    // type. Padding bytes take part in the comparison too.
    static_assert(std::is_standard_layout_v<Params>, "Params is not POD");

    bool operator()(const Params& a, const Params& b) const {
      const void* const lhs = &a;
      const void* const rhs = &b;
      const bool identical = memcmp(lhs, rhs, sizeof(Params)) == 0;
      return identical;
    }
  };
  // ParamsWrapper provides explicit byte-for-byte constructors and assignment
  // operators so that padding bytes are never unwittingly left uninitialized
  // (e.g., when passing Params by value). Default construction zero-fills the
  // whole `pod`, and every copy or move copies all of its bytes with memcpy.
  // This keeps the byte-wise operator== below, and byte-wise hashing such as
  // ParamsWrapperHash, deterministic.
  template <typename T>
  struct ParamsWrapper {
    T pod;
    static_assert(
        std::is_standard_layout_v<T>,
        "ParamsWrapper cannot wrap non-POD data");
    // NOTE(review): memset/memcpy on `pod` are only well-defined when T is
    // also trivially copyable; is_standard_layout_v alone does not check that.

    // Zero-fills every byte of `pod`, padding included.
    ParamsWrapper() noexcept {
      std::memset(&(this->pod), 0, sizeof(this->pod));
    }

    ParamsWrapper(const ParamsWrapper& other) noexcept {
      std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }

    // A "move" is the same byte copy; `other` is left unchanged.
    ParamsWrapper(ParamsWrapper&& other) noexcept {
      std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
    }

    ParamsWrapper& operator=(const ParamsWrapper& other) noexcept {
      assign_bytes_from(other);
      return *this;
    }

    ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
      assign_bytes_from(other);
      return *this;
    }

    // Byte-wise equality over the whole `pod`, padding included.
    inline friend bool operator==(
        const ParamsWrapper& lhs,
        const ParamsWrapper& rhs) noexcept {
      return std::memcmp(&(lhs.pod), &(rhs.pod), sizeof(lhs.pod)) == 0;
    }

    inline friend bool operator!=(
        const ParamsWrapper& lhs,
        const ParamsWrapper& rhs) noexcept {
      return !(lhs == rhs);
    }

   private:
    // memcpy has undefined behavior for overlapping buffers, and a
    // self-assignment (`w = w`) would pass the same address as both source and
    // destination. The bytes are already identical in that case, so skip it.
    void assign_bytes_from(const ParamsWrapper& other) noexcept {
      if (this != &other) {
        std::memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
      }
    }
  };
  // Wrapped version of ParamsHash: hashes the `pod` member of a ParamsWrapper
  // (or of any type exposing a standard-layout `pod` member). Keeping the data
  // in a wrapper allows the outer struct to have custom copy and move
  // constructors for additional safety.
  template <typename Wrapper>
  struct ParamsWrapperHash {
    // `pod` is hashed through its raw bytes, so it must be a standard-layout
    // ("POD") type.
    static_assert(
        std::is_standard_layout_v<decltype(Wrapper::pod)>,
        "ParamsWrapper cannot wrap non-POD data");

    size_t operator()(const Wrapper& params_wrapper) const {
      const auto* next = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
      size_t remaining = sizeof(params_wrapper.pod);
      uint32_t digest = 0x811C9DC5; // 32-bit FNV offset basis
      while (remaining > 0) {
        digest ^= *next;
        digest *= 0x01000193; // 32-bit FNV prime
        ++next;
        --remaining;
      }
      return static_cast<size_t>(digest);
    }
  };
  88. } // namespace at::native