InferSize.h 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #pragma once
  2. #include <ATen/DimVector.h>
  3. #include <c10/core/ScalarType.h>
  4. #include <c10/core/SymIntArrayRef.h>
  5. #include <c10/util/DimVector.h>
  6. #include <c10/util/Optional.h>
  7. #include <sstream>
  8. #include <vector>
  9. namespace at {
  10. // Infers the size of a dim with size -1, if it exists. Also checks that new
  11. // shape is compatible with the number of elements.
  12. //
  13. // templated to handle std::vector<int64_t> and DimVector use cases, see
  14. // below
  15. //
  16. template <typename InputArrayRef, typename NumelType, typename ResultVec>
  17. inline void infer_size_impl(
  18. InputArrayRef shape,
  19. NumelType numel,
  20. ResultVec& res) {
  21. NumelType newsize = 1;
  22. // N.B. this is an index, not a sym dim!
  23. auto infer_dim = std::optional<int64_t>();
  24. for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
  25. if (shape[dim] == -1) {
  26. if (infer_dim) {
  27. throw std::runtime_error("only one dimension can be inferred");
  28. }
  29. infer_dim = dim;
  30. } else if (shape[dim] >= 0) {
  31. newsize *= shape[dim];
  32. } else {
  33. AT_ERROR("invalid shape dimension ", shape[dim]);
  34. }
  35. }
  36. if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
  37. (infer_dim && newsize > 0 && numel % newsize == 0)) {
  38. if (infer_dim) {
  39. // We have a degree of freedom here to select the dimension size; follow
  40. // NumPy semantics and just bail. However, a nice error message is needed
  41. // because users often use `view` as a way to flatten & unflatten
  42. // dimensions and will otherwise be confused why
  43. // empty_tensor.view( 0, 0)
  44. // works yet
  45. // empty_tensor.view(-1, 0)
  46. // doesn't.
  47. TORCH_CHECK(
  48. newsize != 0,
  49. "cannot reshape tensor of 0 elements into shape ",
  50. shape,
  51. " because the unspecified dimension size -1 can be any "
  52. "value and is ambiguous");
  53. res[*infer_dim] = numel / newsize;
  54. }
  55. return;
  56. }
  57. std::ostringstream ss;
  58. ss << "shape '" << shape << "' is invalid for input of size " << numel;
  59. throw std::runtime_error(ss.str());
  60. }
  61. inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
  62. auto res = shape.vec();
  63. infer_size_impl(shape, numel, res);
  64. return res;
  65. }
  66. inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
  67. auto res = at::DimVector(shape);
  68. infer_size_impl(shape, numel, res);
  69. return res;
  70. }
  71. inline at::SymDimVector infer_size_dv(
  72. c10::SymIntArrayRef shape,
  73. c10::SymInt numel) {
  74. auto res = at::SymDimVector(shape);
  75. infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
  76. shape, std::move(numel), res);
  77. return res;
  78. }
  79. } // namespace at