TensorDimApply.h 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
#pragma once

#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <vector>
  4. namespace at::native {
  5. //input tensors are non-zero dim and non-empty
  6. template<typename T1, typename T2, typename Function>
  7. void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
  8. int ndims = self.dim();
  9. int tensor_dim_apply_has_finished = 0;
  10. std::vector<int64_t> counter(ndims, 0);
  11. const T1* self_data = self.const_data_ptr<T1>();
  12. T1* values_data = values.data_ptr<T1>();
  13. T2* indices_data = indices.data_ptr<T2>();
  14. int64_t self_stride = self.stride(dim);
  15. int64_t values_stride = values.stride(dim);
  16. int64_t indices_stride = indices.stride(dim);
  17. int self_dim_size = self.size(dim);
  18. while (!tensor_dim_apply_has_finished) {
  19. func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
  20. if (ndims == 1) {
  21. break;
  22. }
  23. for (const auto dim_i : c10::irange(ndims)) {
  24. if (dim_i == dim) {
  25. if (dim_i == (ndims - 1)) {
  26. tensor_dim_apply_has_finished = 1;
  27. break;
  28. }
  29. continue;
  30. }
  31. counter[dim_i]++;
  32. self_data += self.stride(dim_i);
  33. values_data += values.stride(dim_i);
  34. indices_data += indices.stride(dim_i);
  35. if (counter[dim_i] == self.size(dim_i)) {
  36. if (dim_i == ndims-1) {
  37. tensor_dim_apply_has_finished = 1;
  38. break;
  39. } else {
  40. self_data -= counter[dim_i]*self.stride(dim_i);
  41. values_data -= counter[dim_i]*values.stride(dim_i);
  42. indices_data -= counter[dim_i]*indices.stride(dim_i);
  43. counter[dim_i] = 0;
  44. }
  45. } else {
  46. break;
  47. }
  48. }
  49. }
  50. }
  51. } // namespace at::native