BatchRulesHelper.h

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <c10/util/TypeList.h>
#include <ATen/ATen.h>
#include <ATen/Operators.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/VmapGeneratedPlumbing.h>

#include <utility>

// This file contains helper functions for batching rules.

namespace at::functorch {

TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);

Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);

void vmapIncompatibleInplaceError(const char* schema_name);

Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank);

void check_randomness(RandomnessType randomness);
void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
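// If `tensor` does not already have a batch dim, materialize one of size
// `batch_size` at the front by expanding; otherwise return the tensor unchanged.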
inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sym_sizes();
  SymDimVector expanded_shape;
  expanded_shape.reserve(sizes.size());
  expanded_shape.emplace_back(std::move(batch_size));
  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
  return tensor.expand_symint(expanded_shape);
}
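
// Register `batch_rule` as the vmap (FuncTorchBatched) implementation of `op`
// through the generated plumbing. Illustrative usage inside a
// TORCH_LIBRARY_IMPL block (the batch-rule names below are hypothetical):
//   VMAP_SUPPORT(sin, sin_batch_rule);
//   VMAP_SUPPORT2(add, Tensor, add_Tensor_batch_rule);
// OP_DECOMPOSE registers the at::native decomposition instead of a batch rule.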
#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define VMAP_SUPPORT2(op, overload, batch_rule) \
  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define OP_DECOMPOSE(op)  m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload)  m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
template <typename A, A a, typename C>
struct BasicUnaryBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& tensor,
      optional<int64_t> batch_dim,
      T... extra_args) {
    return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
  }
};

// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
// It is important that this macro is not passed a function pointer!!
#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
    BasicUnaryBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
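
// A pointwise unary op does not care where the batch dim is, so the basic
// unary batch rule just calls the op and passes the batch dim through.
// Illustrative usage:
//   UNARY_POINTWISE(sin);  // registers a vmap rule for aten::sin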
template <typename A, A a, typename C>
struct VariadicBdimsBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& tensor,
      optional<int64_t> batch_dim,
      T... extra_args) {
    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }
};

// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
    VariadicBdimsBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define VARIADIC_BDIMS(op) \
  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));

#define VARIADIC_BDIMS2(op, overload) \
  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
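
// Boxed batching rule for ops whose Tensor inputs only need a uniform
// preprocessing step. It unwraps every Tensor argument at the current vmap
// level, lets `Func` rewrite the (tensor, bdim) pairs in place (for example,
// moving the batch dim to the front), calls the op, and re-wraps every Tensor
// return with its batch dim at 0.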
template<class F, F Func>
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");

  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);
  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
  std::vector<int64_t> tensor_pos;
  for (const auto idx : c10::irange(0, num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
      tensor_pos.push_back(static_cast<int64_t>(idx));
    }
  }
  Func(tensor_inputs);

  size_t tensor_idx = 0;
  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
      torch::jit::push(stack, arguments[arg_idx]);
    } else {
      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
      tensor_idx++;
    }
  }

  op.callBoxed(stack);
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}
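
// Preprocessing used with boxed_tensor_inputs_batch_rule for pointwise ops:
// move every batch dim to the front, then pad each input up to the maximum
// logical (batch-dim-excluded) rank so that broadcasting lines up.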
inline void handle_pointwise_ops(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
  int64_t out_logical_rank = 0;
  for (auto& tensor_input : tensor_inputs) {
    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
  }
  for (auto& tensor_input: tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
  }
}

#define POINTWISE_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());

#define POINTWISE_BOXED2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
  for (auto & tensor_input : tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
  }
}

#define VARIADIC_BDIMS_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
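
// Scan the last `num_args` stack entries, unwrap every Tensor at `cur_level`
// into a (value, optional bdim) pair, record which argument positions were
// tensors, and compute the common batch size (asserting that all batched
// inputs agree on it).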
inline void find_and_unpack_tensors(
    const torch::jit::Stack* stack,
    int64_t num_args,
    int64_t cur_level,
    SmallVector<UnpackedBatchedTensor, 5>* tensors,
    SmallVector<int64_t, 5>* tensors_pos,
    int64_t* batch_size) {

  int64_t computed_batch_size = -1;
  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;

  for (const auto idx : c10::irange(0, num_args)) {
    const auto& ivalue = (*stack)[args_begin + idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
    const auto& tensor_value = std::get<0>(unpacked);
    const auto tensor_bdim = std::get<1>(unpacked);
    if (tensor_bdim.has_value()) {
      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
      if (computed_batch_size == -1) {
        computed_batch_size = candidate_batch_size;
      }
      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
    }

    tensors->push_back(std::move(unpacked));
    tensors_pos->push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
  *batch_size = computed_batch_size;
}
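
// Boxed implementation behind EXISTING_BDIM_ALL_BOXED (see the comment on that
// macro below): every Tensor argument gets a batch dim (materialized if
// missing), the batch dim is folded into dim 0, the op is called, and the
// batch dim is split back out of dim 0 of each Tensor return.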
inline void boxed_existing_bdim_all_batch_rule(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // for each tensor, ensure it has a bdim and reshape it.
  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
  }
}
// Use when all Tensor arguments accept one (normal) batch dim.
// This batching rule expands the batch dim on all Tensors, reshapes it into
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
// This is not the most efficient thing; if there are alternatives, please try
// to use them. Use this only as a last resort.
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
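
// Boxed implementation behind ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED (see the
// comment on that macro below). `feature_rank` is the rank an unbatched input
// is expected to have: if the inputs arrive at that logical rank, the batch
// dim is simply moved to the front; otherwise (rank == feature_rank + 1) it is
// folded into dim 0 before the call and split back out of each Tensor return.
// `contig_tensor_index`, if non-negative, marks one input to make contiguous.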
template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, static_cast<int64_t>(num_arguments), cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
      continue;
    }
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}
// Useful for many NN operators.
// The operator must satisfy the following:
// - All arguments must accept an optional batch dim.
// - All arguments must be the same rank.
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
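
// Illustrative example: for an NN op whose unbatched input is (C, H, W),
// `feature_rank` is 3. Under vmap, a (B, C, H, W) input (logical rank 3) is
// handled by moving the vmap dim to the batch position, while a
// (B, N, C, H, W) input (logical rank 4) has the vmap dim folded into N.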
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
  m.impl(#op, \
         torch::CppFunction::makeFromBoxedFunction<\
             boxed_all_tensors_have_optional_bdim<\
                 feature_rank, \
                 contig_tensor_index>\
             >());
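
// Batch rule for ops that already accept a batch-like leading dimension:
// fold the vmap batch dim into dim 0 of `self`, call the op, then split the
// batch dim back out of dim 0 of the result.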
template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& self,
      optional<int64_t> self_bdim,
      T... extra_args) {
    auto self_ = reshape_dim_into(*self_bdim, 0, self);
    auto out = Func(self_, std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
  }
};

// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
    ExistingBdimBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define EXISTING_BDIM(op) \
  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));

#define EXISTING_BDIM2(op, overload) \
  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
#define INVOKE(object, ptrToMember)  ((object).*(ptrToMember))
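
// Batch rule for in-place unary Tensor methods: invoke the member function
// `Method` (e.g. a method like &Tensor::sin_, illustrative) on `self` with the
// extra args and return `self`; no batch dim is moved.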
template <typename F, F Method, typename... ExtraArgs>
Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>, ExtraArgs... extra_args) {
  INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
  return self;
}
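
// get_bdim_size{4,3,2}: return the batch size taken from the first argument
// that actually has a batch dim; at least one argument must be batched.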
inline int64_t get_bdim_size4(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim,
    const Tensor& c_value, optional<int64_t> c_bdim,
    const Tensor& d_value, optional<int64_t> d_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  if (d_bdim)
    return d_value.size(*d_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size3(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim,
    const Tensor& c_value, optional<int64_t> c_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size2(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  TORCH_INTERNAL_ASSERT(false);
}
// [start, start + 1, ..., stop - 1]
inline VmapDimVector range(int64_t start, int64_t stop) {
  TORCH_INTERNAL_ASSERT(stop >= start);
  VmapDimVector dims;
  dims.reserve(stop - start);
  for (int64_t i = start; i < stop; i++) {
    dims.emplace_back(i);
  }
  return dims;
}
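
// Prepares the two operands of a binary pointwise batch rule: batch dims are
// moved to the front and both tensors are aligned to a common logical rank,
// with type promotion applied when `do_type_promotion` is true.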
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim, const Tensor& other, optional<int64_t> other_batch_dim,
    bool do_type_promotion=true);

} // namespace at::functorch