# mypy: allow-untyped-defs
import functools
import logging
import math
import sys
import typing
from typing import Optional

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)
from torch._decomp.decompositions import (
    _grid_sampler_2d as decomp_grid_sampler_2d,
    pw_cast_for_opmath,
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._dynamo.utils import counters
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.utils import pad_listlike
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    type_to_dtype,
)

from . import config, inductor_prims
from .utils import (
    is_gpu,
    needs_fallback_due_to_atomic_add_limitations,
    use_scatter_fallback,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
quantized_decomposed = torch.ops.quantized_decomposed

inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
    ]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)


def register_decomposition(ops):
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)


# TODO: for now, inductor doesn't handle asserts
# because the condition is symbol -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
    return


# Following `assert_async_msg_decomp`, implement this as a no-op as well.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor, msg):
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
    return


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


@register_decomposition([aten.full])
def full(size, fill_value, **kwargs):
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return torch.full(size, fill_value, **kwargs)
    return NotImplemented


# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
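
# For example (illustrative values, not from the original source): with
# size=(2, 3, 4, 5) and physical_layout=[0, 2, 3, 1] (i.e. NHWC-style storage
# of an NCHW tensor), the inner call allocates a contiguous (2, 4, 5, 3)
# buffer and perm becomes [0, 3, 1, 2], so the final .permute(perm) restores
# the logical (2, 3, 4, 5) shape while keeping the NHWC physical layout.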


@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    if not output_mask[2] or not is_gpu(grad_output.device.type):
        return NotImplemented
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)


@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)


@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(self, batch2):
    if config.coordinate_descent_tuning:
        if self.shape[1] == 1 or batch2.shape[2] == 1:
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        if self.size(1) == 1 and batch2.size(-1) == 1:
            counters["inductor"]["decompose_bmm"] += 1
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented


@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(self, mat1, mat2, beta=1, alpha=1):
    if self.device.type == "cpu":
        if mat1.size(0) == 1 and mat2.size(-1) == 1:
            counters["inductor"]["decompose_addmm"] += 1
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16:
            counters["inductor"]["decompose_addmm"] += 1
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented
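
# A quick sanity check of the small-matrix rewrite above (illustrative only,
# not part of the original source): for mat1 of shape (1, K) and mat2 of
# shape (K, N) with K, N <= 16, mat1.T broadcasts to (K, N), so
# (mat1.T * mat2).sum(dim=0, keepdim=True) computes the same (1, N) result
# as mat1 @ mat2, e.g.
#   mat1 = torch.randn(1, 8); mat2 = torch.randn(8, 4)
#   torch.allclose((mat1.T * mat2).sum(dim=0, keepdim=True), mat1 @ mat2)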


@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(self, input2):
    from torch.fx.experimental.symbolic_shapes import (
        definitely_true,
        guard_size_oblivious,
    )

    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        if self.shape[0] == 1 or input2.shape[1] == 1:
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            counters["inductor"]["decompose_mm"] += 1
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented


# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x):
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat. The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious? A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero. So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
        return len(x.shape) != 1 or guard_size_oblivious(x.shape[0] > 0)

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)
    # when no filtering has occurred, return NotImplemented to prevent
    # infinite recursion (no more decomposition needed)
    return NotImplemented


@register_decomposition([aten.angle])
def angle(x):
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is a real number:
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)


@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)
    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
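
# The rewrite above relies on complex addition being elementwise on the
# underlying real storage: viewing a complex64 tensor as float32 doubles the
# last dimension into interleaved (real, imag) pairs, the float add sums the
# components independently, and the final .view(complex_type) reinterprets
# the result as complex again. Illustrative check (not part of the original
# source):
#   a = torch.randn(4, dtype=torch.complex64)
#   b = torch.randn(4, dtype=torch.complex64)
#   torch.allclose(
#       (a.view(torch.float32) + b.view(torch.float32)).view(torch.complex64),
#       a + b,
#   )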


@register_decomposition([aten.conj_physical])
def conj_physical(self):
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_])
def lift(self):
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
    assert generator is None
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
    return torch.where(torch.isnan(other) | (other > self), self, other)


@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
    return torch.where(torch.isnan(other) | (other < self), self, other)
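
# Unlike minimum/maximum, fmin/fmax prefer the non-NaN operand: the
# torch.where above keeps `self` whenever `other` is NaN, and only NaN in
# both inputs yields NaN. Illustrative behavior (not from the original
# source):
#   fmin(tensor([nan]), tensor([2.])) -> tensor([2.])
#   fmin(tensor([1.]), tensor([nan])) -> tensor([1.])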


@register_decomposition(aten.amax)
def amax(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition(aten.amin)
def amin(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
    return torch.narrow(self, dim, start, length).clone()


@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
    return aten.expand(self, size, implicit=implicit).clone()


@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
    return aten.view(self, size).clone()


@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
    return self.to(dtype).clone()


def get_like_layout(
    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
    # TODO: _to_copy tensor to stride permutation
    if memory_format is torch.preserve_format or memory_format is None:
        return utils.suggest_memory_format(tensor)
    else:
        return memory_format


@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
    self,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    requires_grad=False,
    memory_format=torch.preserve_format,
):
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
):
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
def randint(high, size, **kwargs):
    return aten.randint.low(0, high, size, **kwargs)


@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
        input, packed_weight, bias, weight.size()[0]
    )


@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed):
    def bitcast_u8_to_f32(u8):
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets
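
# Layout assumed by the decomposition above: each row of `packed` holds the
# quantized uint8 embedding values followed by an 8-byte trailer containing
# the fp32 scale and fp32 offset, stored byte-by-byte. The helper reassembles
# those four bytes into an int32 in the host byte order and bitcasts it to
# float32, so a row of width D + 8 dequantizes to
#   float_row = uint8_row[:D] * scale + offset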


@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # We do not expand the grid (_expand_grid=False) on CPU for performance reasons.
    # Local experiments found that expanding the grid from (N, H, W, 2) to
    # (N, C, H, W, 2) speeds up compiled CUDA code by ~5x and CPU code by ~2x
    # in bicubic mode, but slows CPU bilinear mode (channels-first) to ~0.8x.
    # Thus we apply this hack and do not expand the grid for that case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output


@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.Scalar(
            aten._foreach_sub.List(end_tensors, start_tensors), weight
        ),
    )
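
# These rewrites apply the algebraic identities
#   addcmul(s, a, b, value) = s + value * a * b
#   addcdiv(s, a, b, value) = s + value * a / b
#   lerp(start, end, w)     = start + w * (end - start)
# listwise, so each composite _foreach_* op decomposes into two simpler
# _foreach_* ops that Inductor can handle directly.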


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    if training:
        return (a, b, c)
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )


@functools.lru_cache(None)
def fast_random_decomps():
    return {**decompositions, **extra_random_decomps}


def select_decomp_table():
    """decomps can change based on config"""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()
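
# Callers are expected to fetch the table through select_decomp_table()
# rather than reading `decompositions` directly, so the RNG decomps are only
# mixed in when config.fallback_random is off. A minimal sketch of how a
# tracing frontend might consume it (illustrative, not from the original
# source):
#   from torch.fx.experimental.proxy_tensor import make_fx
#   gm = make_fx(fn, decomposition_table=select_decomp_table())(*example_inputs)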


@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
    if is_gpu(self.device.type):
        # This two-step algorithm is the same as eager CUDA, for eager CPU we
        # use a 1-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        source_idx = mask.reshape(-1).cumsum(0) - 1
        return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
    return NotImplemented
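
# The cumsum trick above assigns each True position in the flattened mask the
# index of the source element it should consume. Illustrative values (not
# from the original source): for mask = [0, 1, 1, 0, 1],
# cumsum - 1 = [-1, 0, 1, 1, 2]; the entries at False positions are never
# read because the kernel only gathers where the mask is True.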


@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
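
# Worked example of the affine qparams math above (illustrative values, not
# from the original source): for input values spanning [-1.0, 3.0] with
# quant_min=0, quant_max=255,
#   scale      = (3.0 - (-1.0)) / 255 ≈ 0.0157
#   zero_point = 0 - round(-1.0 / 0.0157) = 64  (clamped to [0, 255])
# so real values map to quantized ones as q = round(x / scale) + zero_point.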


@register_decomposition(aten.put)
def put(self, index, source, accumulate=False):
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index], source.reshape(index.shape), accumulate
    )
    return flattened.reshape(self.shape)


@register_decomposition(aten.put_)
def put_(self, index, source, accumulate=False):
    out = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(out)


@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(grad_output, output, dim, input_dtype):
    new_grad_output = grad_output * output
    sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
    # grad_input = new_grad_output - output * sum_new_grad
    grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)

    # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input.contiguous()
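
# Derivation of the formula above: with s = softmax(x), the Jacobian is
# ds_i/dx_j = s_i * (delta_ij - s_j), so
#   grad_x_i = s_i * (grad_s_i - sum_j grad_s_j * s_j)
#            = (grad_s * s)_i - s_i * sum_j (grad_s * s)_j
# which is exactly new_grad_output - output * sum_new_grad, computed here as
# a fused multiply-add.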


@register_decomposition(aten.index_reduce)
def index_reduce(
    self, dim: int, index, src, reduction_type: str, *, include_self: bool = True
):
    if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
        self.dtype
    ):
        true_division = self.dtype.is_floating_point or self.dtype.is_complex
        ones = torch.ones_like(src)
        if include_self:
            out = self
            counts = torch.ones_like(self).index_add(dim, index, ones)
        else:
            out = self.index_fill(dim, index, 0)
            counts = torch.zeros_like(self).index_add(dim, index, ones)
        counts = counts.masked_fill(counts < 1, 1)
        out = out.index_add(dim, index, src)
        return out / counts if true_division else out // counts

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
    index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    scatter_index = (
        index.to(torch.int64)
        .repeat_interleave(repeats)
        .reshape(index_shape)
        .permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )


@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
    x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
):
    if dilation == 1:
        dilation = [1, 1]
    if padding == 0:
        padding = [0, 0]
    if stride is None:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, 2)
    dilation = pad_listlike(dilation, 2)
    padding = pad_listlike(padding, 2)
    stride = pad_listlike(stride, 2)

    window_size = kernel_size[0] * kernel_size[1]
    # We fall back when using non-default dilation or when the window size is too large
    if (
        torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
            kernel_size, dilation
        )
        or window_size > torch.iinfo(torch.int8).max
    ):
        return NotImplemented

    vals, offsets = prims._low_memory_max_pool2d_with_offsets(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )
    indices = prims._low_memory_max_pool2d_offsets_to_indices(
        offsets,
        kernel_size[1],
        x.size(-1),
        stride,
        padding,
    )
    return vals, indices
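
# Note on the int8 bound above: the low-memory pooling prim returns, for each
# output element, the offset of the max within its pooling window rather than
# a full flat index, and (judging by the torch.iinfo(torch.int8) check) those
# offsets are stored as int8. The decomposition therefore only applies when
# kernel_size[0] * kernel_size[1] fits in that range; the offsets are then
# expanded into regular indices by _low_memory_max_pool2d_offsets_to_indices.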