TensorAdvancedIndexing.h

#pragma once
// Indexing tensors by tensors
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>

namespace at {
struct TensorIterator;
}

namespace at::native {

using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);

using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                  const Tensor& src, const ReductionType& reduce);
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                         const Scalar& value, const ReductionType& reduce);
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                      const Tensor& src, const ReductionType& reduce);
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
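
// The declarations above follow ATen's DispatchStub pattern: the stub is
// defined once in the operator's .cpp, each backend registers a kernel for it,
// and calls dispatch on device type. A minimal sketch is shown below; the
// kernel name and file paths are illustrative assumptions, not part of this
// header.
//
//   // In the operator implementation file (assumption: TensorAdvancedIndexing.cpp):
//   DEFINE_DISPATCH(gather_stub);
//
//   // In a backend kernel file (assumption: cpu/IndexKernel.cpp):
//   static void gather_cpu_kernel(const Tensor& result, const Tensor& self,
//                                 int64_t dim, const Tensor& index) {
//     // ... fill `result` with values of `self` gathered along `dim` at `index` ...
//   }
//   REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
//
//   // At the call site, the first argument selects the backend:
//   gather_stub(result.device().type(), result, self, dim, index);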

TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices);
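
// Example call to index_out (a sketch; `result`, `self`, and `idx0` are assumed
// tensors): entries that are std::nullopt leave the corresponding dimension
// unindexed, so only dim 0 is indexed here.
//
//   c10::List<std::optional<Tensor>> indices;
//   indices.push_back(idx0);          // advanced index for dim 0
//   indices.push_back(std::nullopt);  // dim 1 left untouched
//   at::native::index_out(result, self, indices);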

using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);

DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);

} // namespace at::native