__init__.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. from collections import defaultdict
  4. from functools import wraps
  5. from itertools import chain
  6. from typing import Callable, Dict, List, Sequence, Union
  7. import torch
  8. import torch.library
  9. from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
  10. from torch._prims_common import CustomOutParamAnnotation
  11. from torch.utils import _pytree as pytree
# Names exported by a star import of this package. Other module-level names
# defined here (e.g. global_decomposition_table, remove_decompositions) are
# not listed.
__all__ = [
    "decomposition_table",
    "pre_autograd_decomposition_table",
    "meta_table",
    "register_decomposition",
    "get_decompositions",
    "core_aten_decompositions",
]
  20. # TODO: relax key type here; torch registrations should be possible to; but
  21. # right now this type is accurate
  22. global_decomposition_table: Dict[
  23. str, Dict[torch._ops.OperatorBase, Callable]
  24. ] = defaultdict(dict)
  25. decomposition_table = global_decomposition_table["post_autograd"]
  26. pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
  27. meta_table = global_decomposition_table["meta"]
  28. def _add_op_to_registry(registry, op, fn):
  29. """
  30. This is an internal API for adding an op to the decomposition table.
  31. If op is OpOverload, it will be added to the registry directly.
  32. If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
  33. """
  34. overloads: List[Union[torch._ops.OperatorBase]] = []
  35. if isinstance(op, HigherOrderOperator):
  36. # There's no concept of overloads for HigherOrderOperator
  37. registry[op] = fn
  38. return
  39. elif isinstance(op, OpOverload):
  40. overloads.append(op)
  41. else:
  42. assert isinstance(op, OpOverloadPacket)
  43. for ol in op.overloads():
  44. overloads.append(getattr(op, ol))
  45. for op_overload in overloads:
  46. if op_overload in registry:
  47. raise RuntimeError(f"duplicate registrations for {op_overload}")
  48. # TorchScript dumps a bunch of extra nonsense overloads
  49. # which don't have corresponding dispatcher entries, we need
  50. # to filter those out, e.g aten.add.float_int
  51. if torch._C._dispatch_has_kernel(op_overload.name()):
  52. registry[op_overload] = fn
  53. def _convert_out_params(f):
  54. out_annotation = f.__annotations__.get("out")
  55. # If there are no out params, do not wrap the function.
  56. if not out_annotation:
  57. return f
  58. # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
  59. if getattr(out_annotation, "__origin__", None) is tuple:
  60. sig = inspect.signature(f)
  61. out_names = sig.return_annotation._fields
  62. # If out is a tuple, we need to register a function that unpacks all the out
  63. # elements as this is what native_functions.yaml expects
  64. @wraps(f)
  65. def _fn(*args, **kwargs):
  66. out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
  67. # Either all of the out kwargs are set or none of them
  68. is_none = out_kwargs[0] is None
  69. assert all((o is None) == is_none for o in out_kwargs)
  70. return f(*args, **kwargs, out=None if is_none else out_kwargs)
  71. out_params = [
  72. inspect.Parameter(
  73. o,
  74. kind=inspect.Parameter.KEYWORD_ONLY,
  75. default=None,
  76. annotation=t,
  77. )
  78. for o, t in zip(out_names, out_annotation.__args__)
  79. ]
  80. # Drop the out parameter and concatenate the new kwargs in the signature
  81. params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
  82. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  83. parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
  84. )
  85. # Drop the out parameter and concatenate the new kwargs in the annotations
  86. _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
  87. for o in out_params:
  88. _fn.__annotations__[o.name] = o.annotation
  89. # Propagate that this function is wrapped by `out_wrapper`
  90. _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
  91. return _fn
  92. # Alternatively, there may be a single tensor out parameter with a name
  93. # other than "out". This will need special treatment and is indicated by an
  94. # annotation, which we will remove here so it is not exposed after wrapping.
  95. custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
  96. if custom_out_param_name:
  97. @wraps(f)
  98. def _fn(*args, **kwargs):
  99. out_kwarg = kwargs.pop(custom_out_param_name, None)
  100. return f(*args, **kwargs, out=out_kwarg)
  101. out_param = inspect.Parameter(
  102. custom_out_param_name,
  103. kind=inspect.Parameter.KEYWORD_ONLY,
  104. default=None,
  105. annotation=out_annotation,
  106. )
  107. # Drop the out parameter and concatenate the new kwarg in the signature
  108. sig = inspect.signature(f)
  109. params = chain(
  110. (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
  111. )
  112. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  113. parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
  114. )
  115. # Drop the out parameter and concatenate the new kwargs in the annotations
  116. _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
  117. _fn.__annotations__[out_param.name] = out_param.annotation
  118. return _fn
  119. return f
  120. def register_decomposition(
  121. aten_op, registry=None, *, type="post_autograd", unsafe=False
  122. ):
  123. """
  124. A decorator to register a function as a decomposition to the Python
  125. decomposition table. Use it like this::
  126. @register_decomposition(torch.ops.aten.clamp_min)
  127. def clamp_min(x):
  128. return torch.clamp(self, min=min)
  129. If you are writing a new decomposition, consider contributing it
  130. directly to PyTorch in torch._decomp.decompositions.
  131. This API is experimental; we are almost certainly going to extend
  132. the API when we make decompositions eligible for use in transforms (e.g.,
  133. autograd) and not just backend tracing, where we then need to know if a
  134. decomposition can be used to simulate a transform.
  135. By default, we also will register it to the Meta key of dispatcher,
  136. and replace the c++ Meta implementation if there is already one.
  137. unsafe kwarg is for reuse of this function for registering non-function
  138. things
  139. """
  140. assert type in {"post_autograd", "pre_autograd", "meta"}
  141. def decomposition_decorator(fn: Callable) -> Callable:
  142. orig_fn = fn
  143. if not unsafe:
  144. fn = _convert_out_params(fn)
  145. nonlocal registry
  146. if registry is None:
  147. registry = global_decomposition_table[type]
  148. def register(op):
  149. _add_op_to_registry(registry, op, fn)
  150. # To handle allowing multiple aten_ops at once
  151. pytree.tree_map_(register, aten_op)
  152. return orig_fn
  153. return decomposition_decorator
  154. def get_decompositions(
  155. aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
  156. type: str = "post_autograd",
  157. ) -> Dict[torch._ops.OperatorBase, Callable]:
  158. """
  159. Retrieve a dictionary of decompositions corresponding to the list of
  160. operator overloads and overload packets passed as input. Overload
  161. packets will include all decomposed overloads in the packet. If there is
  162. no decomposition for a requested operator, it is silently ignored.
  163. This API is experimental; we are almost certainly going to give an alternate,
  164. more recommended formulation, where a user provides the set of operators
  165. they know how to implement, and we provide decompositions for everything
  166. not in this set.
  167. """
  168. assert type in {"post_autograd", "pre_autograd", "meta"}
  169. registry = global_decomposition_table[type]
  170. packets_to_overloads = defaultdict(list)
  171. for opo in registry:
  172. if isinstance(opo, (OpOverload, OpOverloadPacket)):
  173. packets_to_overloads[opo.overloadpacket].append(opo)
  174. decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
  175. for op in aten_ops:
  176. if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
  177. for op_overload in packets_to_overloads[op]:
  178. decompositions[op_overload] = registry[op_overload]
  179. elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
  180. decompositions[op] = registry[op]
  181. return decompositions
  182. def remove_decompositions(
  183. decompositions: Dict[torch._ops.OperatorBase, Callable],
  184. aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
  185. ) -> None:
  186. """
  187. Given a dictionary of decompositions obtained from get_decompositions(), removes
  188. operators associated with a list of operator overloads and overload packets passed
  189. as input. If the decomposition dictionary does not contain a decomposition that is
  190. specified to be removed, it is silently ignored.
  191. """
  192. for op in aten_ops:
  193. if isinstance(op, OpOverloadPacket):
  194. for overload_name in op.overloads():
  195. opo = getattr(op, overload_name)
  196. decompositions.pop(opo, None)
  197. elif isinstance(op, OpOverload):
  198. decompositions.pop(op, None)
  199. # populate the table
  200. import torch._decomp.decompositions
  201. import torch._refs
  202. # See NOTE [Core ATen Ops]
  203. #
  204. # list was copied from torch/_inductor/decomposition.py
  205. # excluding decompositions that results in prim ops
  206. # Resulting opset of decomposition is core aten ops
  207. def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
  208. aten = torch.ops.aten
  209. return get_decompositions(
  210. [
  211. aten.addcdiv,
  212. aten.addcdiv_,
  213. aten.addcmul,
  214. aten.addcmul_,
  215. aten.addr,
  216. aten.affine_grid_generator,
  217. aten.all,
  218. aten.aminmax,
  219. aten.arange.default,
  220. aten.arange.start,
  221. aten.avg_pool2d_backward,
  222. aten.baddbmm,
  223. aten.binary_cross_entropy,
  224. aten.binary_cross_entropy_backward,
  225. aten.binary_cross_entropy_with_logits,
  226. aten.block_diag,
  227. aten.celu,
  228. aten.celu_,
  229. aten.clamp_max,
  230. aten.clamp_min,
  231. aten.col2im,
  232. aten.count_nonzero,
  233. aten.linalg_cross,
  234. aten.cudnn_batch_norm,
  235. aten.cudnn_batch_norm_backward,
  236. aten.miopen_batch_norm_backward,
  237. aten.deg2rad,
  238. aten.deg2rad_,
  239. aten.detach,
  240. aten.diag_embed,
  241. aten.diagonal_backward,
  242. aten.dot,
  243. aten.vdot,
  244. aten.elu,
  245. aten.elu_,
  246. aten.elu_backward,
  247. aten._embedding_bag,
  248. aten.embedding_dense_backward,
  249. aten.empty_like,
  250. aten._euclidean_dist.default,
  251. aten.expand_as,
  252. aten.eye,
  253. aten.fill,
  254. aten.fill_,
  255. aten.floor_divide,
  256. aten.frac,
  257. aten.frac_,
  258. aten._fused_moving_avg_obs_fq_helper,
  259. aten.gelu_,
  260. aten.gelu_backward,
  261. aten.glu,
  262. aten.glu_backward,
  263. aten.hardshrink,
  264. aten.hardsigmoid,
  265. aten.hardsigmoid_,
  266. aten.hardsigmoid_backward,
  267. aten.hardswish,
  268. aten.hardswish_,
  269. aten.hardswish_backward,
  270. aten.hardtanh_,
  271. aten.hardtanh_backward,
  272. aten.heaviside,
  273. aten.heaviside_,
  274. aten.huber_loss,
  275. aten.huber_loss_backward,
  276. aten.im2col,
  277. aten.index_add,
  278. aten.index_add_,
  279. aten.index_copy,
  280. aten.index_copy_,
  281. aten.index_fill,
  282. aten.index_fill_,
  283. aten.isin,
  284. aten.isneginf,
  285. aten.isposinf,
  286. aten.l1_loss,
  287. aten._lazy_clone,
  288. aten._test_parallel_materialize,
  289. aten.leaky_relu_,
  290. aten.leaky_relu_backward,
  291. aten.lerp,
  292. aten.lerp_,
  293. aten.linspace,
  294. aten.logaddexp,
  295. aten.logaddexp2,
  296. aten.logit,
  297. aten.logit_,
  298. aten.logit_backward,
  299. aten.log_sigmoid_backward,
  300. aten.log_sigmoid_forward,
  301. aten._log_softmax_backward_data,
  302. aten.logspace,
  303. aten.logsumexp.default,
  304. aten.masked_fill,
  305. aten.masked_fill_,
  306. aten.mish,
  307. aten.mish_,
  308. aten.mse_loss,
  309. aten.mse_loss_backward,
  310. aten.multi_margin_loss,
  311. aten.multilabel_margin_loss_forward,
  312. aten.mv,
  313. aten.mvlgamma,
  314. aten.mvlgamma_,
  315. aten.nansum,
  316. aten.nan_to_num,
  317. aten.nan_to_num_,
  318. aten.narrow,
  319. aten.native_batch_norm_backward,
  320. aten.native_dropout_backward,
  321. aten.native_group_norm_backward,
  322. aten.native_layer_norm_backward,
  323. aten.new_empty,
  324. aten.new_full,
  325. aten.new_ones,
  326. aten.new_zeros,
  327. aten.nll_loss_backward,
  328. aten.nll_loss_forward,
  329. aten.norm,
  330. aten.ones,
  331. aten.ones_like,
  332. aten.pixel_shuffle,
  333. aten.pixel_unshuffle,
  334. aten._prelu_kernel,
  335. aten._prelu_kernel_backward,
  336. aten._reshape_alias,
  337. aten.rad2deg,
  338. aten.rad2deg_,
  339. aten.reflection_pad1d,
  340. aten.reflection_pad2d,
  341. aten.reflection_pad3d,
  342. aten.replication_pad1d,
  343. aten.replication_pad2d,
  344. aten.replication_pad3d,
  345. aten.renorm,
  346. aten.renorm_,
  347. aten.replication_pad2d,
  348. aten.resize_as,
  349. aten.roll,
  350. aten.rot90,
  351. aten.rrelu_with_noise,
  352. aten.rrelu_with_noise_,
  353. aten.rsub,
  354. aten._scaled_dot_product_flash_attention_for_cpu.default,
  355. aten.select_backward,
  356. aten.select_scatter,
  357. aten.sgn,
  358. aten.sgn_,
  359. aten.sigmoid_backward,
  360. aten.silu,
  361. aten.silu_,
  362. aten.silu_backward,
  363. aten.sinc,
  364. aten.sinc_,
  365. aten.slice_backward,
  366. aten.smooth_l1_loss,
  367. aten.smooth_l1_loss_backward,
  368. aten.soft_margin_loss,
  369. aten.soft_margin_loss_backward,
  370. aten._softmax_backward_data,
  371. aten.softplus,
  372. aten.softplus_backward,
  373. aten.softshrink,
  374. aten.special_entr,
  375. aten.special_log_ndtr,
  376. aten.special_xlog1py,
  377. aten.split.Tensor,
  378. aten.split_with_sizes_copy,
  379. aten.squeeze.default,
  380. aten.squeeze.dim,
  381. aten.std,
  382. aten.std_mean,
  383. aten.stack,
  384. aten.sum.default,
  385. aten.sum.out,
  386. aten.t,
  387. aten.take,
  388. aten.tanh_backward,
  389. aten.threshold,
  390. aten.threshold_,
  391. aten.threshold_backward,
  392. aten.trace,
  393. aten.transpose.int,
  394. aten.tril,
  395. aten.tril_,
  396. aten.triu,
  397. aten.triu_,
  398. aten.unbind,
  399. aten.unfold_backward,
  400. aten.unfold_copy,
  401. aten._unsafe_index,
  402. aten.unsafe_split.Tensor,
  403. aten.unsafe_split_with_sizes,
  404. aten._unsafe_view,
  405. aten.upsample_linear1d,
  406. aten.upsample_bilinear2d,
  407. aten.upsample_trilinear3d,
  408. aten.upsample_nearest2d_backward,
  409. aten.view_as_complex,
  410. aten.xlogy,
  411. aten.xlogy_,
  412. aten.zero,
  413. aten.zero_,
  414. aten.zeros,
  415. aten.zeros_like,
  416. aten._chunk_cat,
  417. aten._weight_norm_interface,
  418. ]
  419. )