// TensorShape.h — shape-checking helpers for cat / split / _chunk_cat.
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <c10/util/irange.h>
  4. #include <ATen/core/IListRef.h>
  5. namespace at::native {
  6. TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
  7. inline bool cat_should_skip_tensor(const Tensor& t) {
  8. return t.sym_numel() == 0 && t.dim() == 1;
  9. }
  10. // Check to see if the shape of tensors is compatible
  11. // for being concatenated along a given dimension.
  12. inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
  13. int64_t first_dims = first.dim();
  14. int64_t second_dims = second.dim();
  15. TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
  16. first_dims, " and ", second_dims);
  17. for (const auto dim : c10::irange(first_dims)) {
  18. if (dim == dimension) {
  19. continue;
  20. }
  21. int64_t first_dim_size = first.sizes()[dim];
  22. int64_t second_dim_size = second.sizes()[dim];
  23. TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
  24. dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
  25. }
  26. }
  27. inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
  28. int64_t i = 0;
  29. for(const Tensor& t : tensors) {
  30. TORCH_CHECK(t.dim() > 0,
  31. "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
  32. i++;
  33. }
  34. }
  35. inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
  36. TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  37. TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
  38. int64_t dim_size = self.size(dim);
  39. TORCH_CHECK(split_size > 0 || dim_size == 0,
  40. "split_size can only be 0 if dimension size is 0, "
  41. "but got dimension size of ", dim_size);
  42. // if split_size is 0 and dimension size is 0, there is 1 split.
  43. int64_t num_splits = 1;
  44. if (split_size != 0) {
  45. // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
  46. // (returns a single split). We might want to error here, but keep it for BC.
  47. num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  48. }
  49. return num_splits;
  50. }
  51. inline bool have_same_ndims(TensorList tensors) {
  52. auto ndim = tensors[0].dim();
  53. for (const auto tensor_idx : c10::irange(tensors.size())) {
  54. if(tensors[tensor_idx].dim() != ndim) {
  55. return false;
  56. }
  57. }
  58. return true;
  59. }
  60. inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
  61. auto tensor_zero_size = tensors[0].sizes();
  62. std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
  63. for (const auto i : c10::irange(tensors.size())) {
  64. at::Tensor tensor = tensors[i];
  65. for(const auto j : c10::irange(dim)) {
  66. TORCH_CHECK(
  67. tensor.size(j) == leading_dim_sizes[j],
  68. "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
  69. );
  70. }
  71. }
  72. }
  73. inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
  74. TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
  75. TORCH_CHECK(!tensors.empty(),
  76. "_chunk_cat expects a non-empty input tensor list");
  77. auto expected_dtype = tensors[0].dtype();
  78. auto expected_device = tensors[0].device();
  79. for(const auto i : c10::irange(tensors.size())) {
  80. TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
  81. TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
  82. TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
  83. }
  84. if (have_same_ndims(tensors)) {
  85. dim = maybe_wrap_dim(dim, tensors[0].dim());
  86. } else {
  87. TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
  88. for(const auto i : c10::irange(tensors.size())) {
  89. TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
  90. }
  91. }
  92. leading_dimension_matches(tensors, dim);
  93. return dim;
  94. }
  95. } // namespace at::native