ForeachUtils.h

#pragma once

#include <ATen/Device.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif

#include <unordered_map>
#include <vector>

namespace at::native {
namespace {

// Check if the tensor list has either a boolean tensor or an integer tensor.
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  return std::any_of(
      tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
        return at::isIntegralType(t.scalar_type(), includeBool);
      });
}

// Check if the tensor list has any bool tensors.
inline bool has_bool_tensor(TensorList tensors) {
  return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
    return t.scalar_type() == ScalarType::Bool;
  });
}

// Check foreach API restrictions:
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
inline void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(
      tensors.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
  TORCH_CHECK(
      tensors1.size() == tensors3.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors3.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(
      tensors1.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list, got ",
      tensors1.size(),
      " and ",
      scalars.size());
}
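// A minimal usage sketch (illustrative only, not part of this header):
// `my_foreach_mul_` is a hypothetical op taking a tensor list and a matching
// scalar list. The checks above reject empty or mismatched lists up front,
// before any kernel is launched:
//
//   void my_foreach_mul_(TensorList self, ArrayRef<Scalar> scalars) {
//     // Throws if `self` is empty or `self.size() != scalars.size()`.
//     check_foreach_api_restrictions(self, scalars);
//     // ... dispatch ...
//   }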
// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
    ArrayRef<TensorList> tensorLists,
    const bool skip_dtype_check = false) {
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  auto is_tensor_okay = [&](const Tensor& tensor) {
    return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
        tensor.device() == expected_device && tensor.layout() == at::kStrided &&
        tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  return true;
}

// Helper function called in check_fast_path_restrictions to check if
// corresponding tensors in tensor lists have the same sizes and strides.
inline bool _check_tensors_share_sizes_and_strides(
    ArrayRef<TensorList> tensorLists) {
  for (const auto i : c10::irange(1, tensorLists.size())) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
          tensorLists[0][j].strides() != tensorLists[i][j].strides()) {
        return false;
      }
    }
  }

  return true;
}

// Helper function called in check_fast_path_restrictions to check whether
// all tensors type promote properly with the scalars in scalarList. This
// function assumes that _check_tensors_share_device_and_dtype has already been
// called so that all corresponding tensors in tensorLists have the same dtype.
// Then, it is sufficient to check the type promotion with just one tensorList.
inline bool _check_tensors_do_type_promotion_with_scalars(
    TensorList tensorList,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  for (const auto i : c10::irange(tensorList.size())) {
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float) {
      if (at::isIntegralType(
              tensorList[i].scalar_type(), /*includeBool*/ true)) {
        return false;
      }
    }
    if (!scalarList.empty()) {
      const auto& scalar =
          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorList[i];
      // note(mkozuki): This check might be responsible for
      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }
  return true;
}
// To go via the 'fast' path, several conditions must be satisfied:
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device.
// - All tensors must have strided layout.
// - All tensors must be non-overlapping and dense.
// - The resulting tensor must have the same dtype as the input one.
//
// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
// ``does_op_promote_integer_inputs_to_float=true`` means that the result of
// the op will be float even if the inputs are integer or boolean, which the
// fast path currently does not support. In short, when turned on, this flag
// keeps the op from going down the fast path.
//
// Please make sure to call check_foreach_api_restrictions before calling this
// method. There is a set of preconditions that have to be satisfied.
inline bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return _check_tensors_share_device_and_dtype(tensorLists) &&
      _check_tensors_share_sizes_and_strides(tensorLists) &&
      _check_tensors_do_type_promotion_with_scalars(
          tensorLists[0],
          scalarList,
          does_op_promote_integer_inputs_to_float);
}
// Unpack a 1-D CPU tensor that packs per-tensor scalars into a
// std::vector<c10::Scalar> of the expected length.
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
    const Tensor& scalarList_,
    int64_t expect_length) {
  std::vector<c10::Scalar> scalarList;
  TORCH_CHECK(
      scalarList_.device() == c10::kCPU,
      "Expected scalars to be on CPU, got ",
      scalarList_.device(),
      " instead.");
  TORCH_CHECK(
      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
  TORCH_CHECK(
      scalarList_.dim() == 1,
      "Expected packed scalar Tensor to be of dimension 1. Got ",
      scalarList_.dim(),
      " instead.");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16,
      scalarList_.scalar_type(),
      "convert_tensor_to_scalar_list",
      [&]() {
        const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
        TORCH_CHECK(
            (expect_length == scalarList_.size(0)),
            "Expected length of scalars to match input of length ",
            expect_length,
            " but got ",
            scalarList_.size(0),
            " instead.");
        for (int64_t i = 0; i < scalarList_.size(0); i++) {
          scalarList.emplace_back(scalar_data[i]);
        }
      });
  return scalarList;
}
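// A minimal usage sketch (illustrative only; `packed_scalars` and `self` are
// hypothetical arguments, not part of this header). An op that receives its
// scalars packed in a contiguous 1-D CPU tensor can unpack them first and then
// reuse the list-based checks above:
//
//   // Throws via the TORCH_CHECKs above if `packed_scalars` is not a
//   // contiguous 1-D CPU tensor with self.size() elements.
//   std::vector<c10::Scalar> scalars =
//       convert_tensor_to_scalar_list(packed_scalars, self.size());
//   check_foreach_api_restrictions(self, scalars);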
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(
      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}

// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    TensorList tensors1,
    TensorList tensors2,
    bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route(
      {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
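// A minimal dispatch sketch (illustrative only; `foo_fast_` and `foo_slow_`
// are hypothetical kernels, not part of this header). A typical foreach op
// first validates its inputs, then picks the fused or the per-tensor path:
//
//   void _foreach_foo_(TensorList self, TensorList other) {
//     check_foreach_api_restrictions(self, other);
//     if (can_use_fast_route(self, other)) {
//       foo_fast_(self, other);  // single fused kernel over all tensors
//     } else {
//       foo_slow_(self, other);  // falls back to a per-tensor loop
//     }
//   }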
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
using IndicesT = std::vector<size_t>;
using nested_optional_tensorvec_t =
    std::vector<std::vector<std::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
using FlatMap = std::unordered_map<
    DeviceDtypeKey,
    TensorsAndIndicesT,
    ParamsHash<DeviceDtypeKey>>;
// Group the tensors of nested tensor lists into buckets keyed by the
// (device, dtype) of the corresponding tensor in the first list, optionally
// recording the original indices of the tensors placed in each bucket.
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
    const nested_optional_tensorvec_t& nested_tensorlist,
    const bool with_indices) {
  FlatMap grouped_tensors_with_indices;
  TORCH_CHECK(!nested_tensorlist.empty());
  TORCH_CHECK(!nested_tensorlist[0].empty());
  const auto num_lists = nested_tensorlist.size();
  const auto num_tensors = nested_tensorlist[0].size();

  TORCH_CHECK(std::all_of(
      nested_tensorlist.cbegin(),
      nested_tensorlist.cend(),
      [&](const auto& tensorlist) -> bool {
        // note(crcrpar): Allow empty tensorlists following
        // ref:
        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
        return tensorlist.size() == num_tensors || tensorlist.size() == 0;
      }));

  for (const auto& tensor_index : c10::irange(num_tensors)) {
    const auto key = [&]() -> DeviceDtypeKey {
      const auto t = nested_tensorlist[0][tensor_index];
      TORCH_CHECK(
          t.has_value(),
          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
          "the ",
          tensor_index,
          "-th Tensor is not.");
      return {t->device(), t->scalar_type()};
    }();
    TORCH_CHECK(
        std::all_of(
            nested_tensorlist.cbegin(),
            nested_tensorlist.cend(),
            [&](const auto& tensorlist) -> bool {
              if (tensorlist.size() == 0) {
                return true;
              }
              const auto& tensor = tensorlist[tensor_index];
              // note(crcrpar): Currently the scope of this function is
              // optimizers, so there could be `state_steps` and other scalars
              // whose elements are float tensors no matter what the
              // parameter's dtype is.
              if (!tensor.has_value()) {
                return true;
              } else {
                const auto s = tensor->scalar_type();
                const auto d = tensor->device();
                // Note: `step` or `state_step` is float32 by default.
                if (key.first == d) {
                  return key.second == s || s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else if (d.is_cpu()) {
                  // note(crcrpar): There are some test cases (e.g.
                  // TestOptim::test_adam) where state_steps are on CPU and the
                  // others are on CUDA. Currently a state_step Tensor has the
                  // dtype of float.
                  return s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else {
                  return false;
                }
              }
            }),
        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
    if (!grouped_tensors_with_indices.count(key)) {
      grouped_tensors_with_indices.insert(
          {key,
           TensorsAndIndicesT{
               [&]() -> nested_optional_tensorvec_t {
                 nested_optional_tensorvec_t nested_tensorvec;
                 nested_tensorvec.reserve(num_lists);
                 for (const auto& i : c10::irange(num_lists)) {
                   std::vector<std::optional<at::Tensor>> tensors;
                   if (!nested_tensorlist[i].empty()) {
                     // NB: num_tensors is the max possible length for any of
                     // the inner lists of tensor references. Reserving the max
                     // trades memory for perf. This should not have significant
                     // impact.
                     tensors.reserve(num_tensors);
                   }
                   nested_tensorvec.emplace_back(tensors);
                 }
                 return nested_tensorvec;
               }(),
               [&]() -> IndicesT {
                 if (!with_indices) {
                   return {};
                 } else {
                   IndicesT indices;
                   indices.reserve(num_tensors);
                   return indices;
                 }
               }()}});
    }
    for (const auto& list_index : c10::irange(num_lists)) {
      if (!nested_tensorlist[list_index].empty()) {
        grouped_tensors_with_indices[key].first[list_index].emplace_back(
            nested_tensorlist[list_index][tensor_index]);
      }
    }
    if (with_indices) {
      grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
    }
  }

  return grouped_tensors_with_indices;
}
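// A minimal usage sketch (illustrative only; `params`, `grads`, and `exp_avgs`
// are hypothetical optimizer state lists, not part of this header). Each
// bucket of the returned FlatMap can then be handed to a fused kernel that
// only supports a single device/dtype combination:
//
//   nested_optional_tensorvec_t nested{params, grads, exp_avgs};
//   const auto grouped = _group_tensors_by_first_tensors_device_and_dtype(
//       nested, /*with_indices=*/false);
//   for (const auto& [device_dtype, tensors_and_indices] : grouped) {
//     // tensors_and_indices.first[0] holds this bucket's params,
//     // .first[1] the grads, .first[2] the exp_avgs, all on
//     // device_dtype.first with dtype device_dtype.second (except the
//     // float32/64 `step`-like tensors noted above).
//   }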
} // namespace
} // namespace at::native