WrapDimUtilsMulti.h 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #pragma once
  2. #include <ATen/WrapDimUtils.h>
  3. #include <c10/core/TensorImpl.h>
  4. #include <c10/util/irange.h>
  5. #include <bitset>
  6. #include <sstream>
  7. namespace at {
  8. // This is in an extra file to work around strange interaction of
  9. // bitset on Windows with operator overloading
  10. constexpr size_t dim_bitset_size = 64;
  11. static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
  12. OptionalIntArrayRef opt_dims,
  13. size_t ndims) {
  14. TORCH_CHECK(
  15. ndims <= dim_bitset_size,
  16. "only tensors with up to ",
  17. dim_bitset_size,
  18. " dims are supported");
  19. std::bitset<dim_bitset_size> seen;
  20. if (opt_dims.has_value()) {
  21. auto dims = opt_dims.value();
  22. for (const auto i : c10::irange(dims.size())) {
  23. size_t dim = maybe_wrap_dim(dims[i], static_cast<int64_t>(ndims));
  24. TORCH_CHECK(
  25. !seen[dim],
  26. "dim ",
  27. dim,
  28. " appears multiple times in the list of dims");
  29. seen[dim] = true;
  30. }
  31. } else {
  32. for (size_t dim = 0; dim < ndims; dim++) {
  33. seen[dim] = true;
  34. }
  35. }
  36. return seen;
  37. }
  38. } // namespace at