// Repeat.h
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>

#include <cstdint>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif
  10. namespace at::native {
  11. template <
  12. typename index_t,
  13. void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
  14. static inline Tensor repeat_interleave_common(
  15. const Tensor& repeats,
  16. std::optional<int64_t> output_size) {
  17. TORCH_CHECK(
  18. repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
  19. TORCH_CHECK(
  20. repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
  21. "repeats has to be Long or Int tensor");
  22. if (repeats.size(0) == 0) {
  23. return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  24. }
  25. Tensor repeats_ = repeats.contiguous();
  26. Tensor cumsum = repeats.cumsum(0);
  27. int64_t total;
  28. if (output_size.has_value()) {
  29. total = output_size.value();
  30. } else {
  31. total = cumsum[-1].item<int64_t>();
  32. TORCH_CHECK(
  33. (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
  34. }
  35. Tensor result = at::empty({total}, repeats.options());
  36. const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
  37. const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
  38. index_t* result_ptr = result.data_ptr<index_t>();
  39. compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
  40. return result;
  41. }
  42. } // namespace at::native