#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif
namespace at::native {

// Validates dtypes, shapes, and mode compatibility of the embedding_bag
// inputs before any computation is dispatched.
void check_arguments(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    bool include_last_offset);

// Computes the number of indices contributing to each bag.
void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset,
    const bool requires_grad);

// For max mode, records which index produced the maximum in each bag.
void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);

// Builds the mapping from each index position to the bag it belongs to.
void make_offset2bag_out(
    Tensor& offset2bag,
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);
#ifdef USE_FBGEMM

// Storage mixin holding one generated FBGEMM EmbeddingSpMDM kernel together
// with the block size (embedding dimension) it was generated for.
// blockSize == -1 means no kernel is cached.
template<bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
  using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;

  int64_t blockSize = -1;
  TCallback callback = nullptr;

  static TCallback generateCallback(int64_t block_size) {
    return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
        block_size,
        has_weight,
        /* normalize_by_lengths */ false,
        /* prefetch */ 16,
        /* is_weight_positional */ false,
        /* use_offsets */ true);
  }

  _CallbackAndBlockSize() = default;

  explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
    : blockSize(maybe_block_size.value_or(-1))
    , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
  {}
};
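
// Illustrative only (not part of the original header): constructing a mixin
// with a block size eagerly generates and caches a kernel for it, e.g.
//
//   _CallbackAndBlockSize</*has_weight=*/false, int64_t, float> slot(
//       std::make_optional<int64_t>(128));
//   // slot.blockSize == 128, slot.callback is ready to invoke
//
// while default construction leaves blockSize at -1 and callback at nullptr.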
template<typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {

  _EmbeddingBagKernelCacheImpl() = default;

  // Use each of the mixins to store the corresponding kernel and block size.
  explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
    : StorageMixins(maybe_block_size)...
  {}

  // This method is thread-safe (call sites may call from different threads).
  template<bool has_weight, typename TIndex, typename TData>
  typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  getCallback(int64_t block_size) const {
    // If the cache doesn't store the kernel for the incoming block size
    // (i.e. it differs from the one stored in the corresponding mixin),
    // regenerate the kernel without writing it into the cache, so no
    // locking is needed.
    if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
      return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
    }
    // Otherwise retrieve the cached kernel from the corresponding mixin.
    return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  }
};
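
// A minimal usage sketch (illustrative, not part of the original header):
// a call site in EmbeddingBag.cpp holding a cache constructed for the
// weight's embedding dimension can resolve a kernel like this:
//
//   _EmbeddingBagKernelCache cache(std::make_optional<int64_t>(embedding_dim));
//   auto kernel =
//       cache.getCallback</*has_weight=*/false, int64_t, float>(embedding_dim);
//
// If the requested block size matches the cached one, the stored kernel is
// returned; otherwise a fresh kernel is generated for just that call.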
// Instantiate the cache with the list of storage mixins, one for each of the
// eight _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file:
// {weighted, unweighted} x {int32_t, int64_t indices} x {float, 16-bit data}
// (unsigned short presumably carries 16-bit floating-point weights as raw bits).
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else

// Stub so that call sites compile unchanged when FBGEMM is unavailable.
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
};

#endif
// Shared CPU implementation writing into preallocated output tensors;
// max_indices may be null when max indices are not needed.
void _embedding_bag_cpu_impl_out(
    Tensor& output,
    Tensor& offset2bag,
    Tensor& bag_size,
    Tensor* max_indices,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode = 0,
    const std::optional<Tensor>& per_sample_weights = std::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

void _embedding_bag_cpu_out(
    at::Tensor& output,
    at::Tensor& offset2bag,
    at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const std::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const std::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

} // namespace at::native
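
// Illustrative only (not part of the original header): a rough sketch of
// driving _embedding_bag_cpu_impl_out directly, assuming the output tensors
// have been set up via the helpers above (mode 0 is the sum reduction):
//
//   at::Tensor weight = at::randn({10, 4});
//   at::Tensor indices = at::tensor({1, 3, 5, 7}, at::kLong);
//   at::Tensor offsets = at::tensor({0, 2}, at::kLong);  // two bags
//   at::Tensor output, offset2bag, bag_size;
//   // ...allocate output/offset2bag/bag_size appropriately, then:
//   at::native::_embedding_bag_cpu_impl_out(
//       output, offset2bag, bag_size, /*max_indices=*/nullptr,
//       weight, indices, offsets, /*mode=*/0);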