# mypy: allow-untyped-defs
import inspect
from collections import defaultdict
from functools import wraps
from itertools import chain
from typing import Callable, Dict, List, Sequence, Union

import torch
import torch.library
from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
from torch._prims_common import CustomOutParamAnnotation
from torch.utils import _pytree as pytree

__all__ = [
    "decomposition_table",
    "pre_autograd_decomposition_table",
    "meta_table",
    "register_decomposition",
    "get_decompositions",
    "core_aten_decompositions",
    "remove_decompositions",
]

# TODO: relax key type here; torch registrations should be possible too, but
# right now this type is accurate
global_decomposition_table: Dict[
    str, Dict[torch._ops.OperatorBase, Callable]
] = defaultdict(dict)

decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]


def _add_op_to_registry(registry, op, fn):
    """
    This is an internal API for adding an op to the decomposition table.

    If op is an OpOverload, it is added to the registry directly.
    If op is an OpOverloadPacket, every valid overload in the packet is
    added to the registry.
    """
    overloads: List[torch._ops.OperatorBase] = []
    if isinstance(op, HigherOrderOperator):
        # There's no concept of overloads for HigherOrderOperator
        registry[op] = fn
        return
    elif isinstance(op, OpOverload):
        overloads.append(op)
    else:
        assert isinstance(op, OpOverloadPacket)
        for ol in op.overloads():
            overloads.append(getattr(op, ol))

    for op_overload in overloads:
        if op_overload in registry:
            raise RuntimeError(f"duplicate registrations for {op_overload}")
        # TorchScript dumps a bunch of extra nonsense overloads which don't
        # have corresponding dispatcher entries; we need to filter those out,
        # e.g. aten.add.float_int
        if torch._C._dispatch_has_kernel(op_overload.name()):
            registry[op_overload] = fn
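

# Illustrative sketch (hypothetical registry; kept as a comment so nothing runs
# at import). A packet fans out to every overload that has a real dispatcher
# kernel, while a single OpOverload registers as-is:
#
#   my_registry: Dict[torch._ops.OperatorBase, Callable] = {}
#   _add_op_to_registry(my_registry, torch.ops.aten.add, lambda *a, **kw: None)
#   # my_registry now maps e.g. aten.add.Tensor, aten.add.Scalar, ... to the fn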


def _convert_out_params(f):
    out_annotation = f.__annotations__.get("out")

    # If there are no out params, do not wrap the function.
    if not out_annotation:
        return f

    # Hack to detect when out is a Tuple. There seems to be no pretty way of
    # doing this.
    if getattr(out_annotation, "__origin__", None) is tuple:
        sig = inspect.signature(f)
        out_names = sig.return_annotation._fields
        # If out is a tuple, we need to register a function that unpacks all the out
        # elements as this is what native_functions.yaml expects
        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
            # Either all of the out kwargs are set or none of them
            is_none = out_kwargs[0] is None
            assert all((o is None) == is_none for o in out_kwargs)
            return f(*args, **kwargs, out=None if is_none else out_kwargs)

        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=t,
            )
            for o, t in zip(out_names, out_annotation.__args__)
        ]
        # Drop the out parameter and concatenate the new kwargs in the signature
        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        for o in out_params:
            _fn.__annotations__[o.name] = o.annotation

        # Propagate that this function is wrapped by `out_wrapper`
        _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper  # type: ignore[attr-defined]

        return _fn

    # Alternatively, there may be a single tensor out parameter with a name
    # other than "out". This will need special treatment and is indicated by an
    # annotation, which we will remove here so it is not exposed after wrapping.
    custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
    if custom_out_param_name:

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwarg = kwargs.pop(custom_out_param_name, None)
            return f(*args, **kwargs, out=out_kwarg)

        out_param = inspect.Parameter(
            custom_out_param_name,
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_annotation,
        )

        # Drop the out parameter and concatenate the new kwarg in the signature
        sig = inspect.signature(f)
        params = chain(
            (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
        )
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )

        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        _fn.__annotations__[out_param.name] = out_param.annotation

        return _fn

    return f
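

# Illustrative sketch (hypothetical names; kept as a comment so nothing runs at
# import). A decomposition whose ``out`` annotation is a Tuple and whose return
# annotation is a NamedTuple is rewrapped so each tuple field becomes its own
# keyword-only argument, the calling convention native_functions.yaml expects
# for multi-output out= overloads:
#
#   from typing import NamedTuple, Tuple
#
#   class MinMax(NamedTuple):
#       min: torch.Tensor
#       max: torch.Tensor
#
#   def aminmax_like(x, *, out: Tuple[torch.Tensor, torch.Tensor] = None) -> MinMax:
#       ...
#
#   # out_wrapper (torch._prims_common.wrappers) normally sets this marker;
#   # set it manually for a standalone sketch:
#   aminmax_like._torch_decompositions_out_wrapper = "sketch"
#
#   wrapped = _convert_out_params(aminmax_like)
#   # wrapped(x, min=a, max=b) calls aminmax_like with out=(a, b);
#   # passing neither keyword forwards out=None.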


def register_decomposition(
    aten_op, registry=None, *, type="post_autograd", unsafe=False
):
    """
    A decorator to register a function as a decomposition to the Python
    decomposition table. Use it like this::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(x, min):
            return torch.clamp(x, min=min)

    If you are writing a new decomposition, consider contributing it
    directly to PyTorch in torch._decomp.decompositions.

    This API is experimental; we are almost certainly going to extend
    the API when we make decompositions eligible for use in transforms (e.g.,
    autograd) and not just backend tracing; at that point we will need to know
    whether a decomposition can be used to simulate a transform.

    By default, we also register it to the Meta key of the dispatcher,
    replacing any existing C++ Meta implementation.

    The unsafe kwarg skips the out-parameter conversion, so this decorator
    can be reused to register things that are not plain functions.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    def decomposition_decorator(fn: Callable) -> Callable:
        orig_fn = fn
        if not unsafe:
            fn = _convert_out_params(fn)

        nonlocal registry
        if registry is None:
            registry = global_decomposition_table[type]

        def register(op):
            _add_op_to_registry(registry, op, fn)

        # To handle allowing multiple aten_ops at once
        pytree.tree_map_(register, aten_op)
        return orig_fn

    return decomposition_decorator
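

# Illustrative sketch (hypothetical decomposition; kept as a comment so nothing
# runs at import). The decorated function is registered for the given op; a
# pytree (e.g. a list) of ops registers the same function for each of them:
#
#   @register_decomposition(torch.ops.aten.celu)
#   def my_celu(x, alpha=1.0):
#       return torch.where(x > 0, x, alpha * torch.expm1(x / alpha))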


def get_decompositions(
    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
    type: str = "post_autograd",
) -> Dict[torch._ops.OperatorBase, Callable]:
    """
    Retrieve a dictionary of decompositions corresponding to the list of
    operator overloads and overload packets passed as input. Overload
    packets will include all decomposed overloads in the packet. If there is
    no decomposition for a requested operator, it is silently ignored.

    This API is experimental; we are almost certainly going to give an alternate,
    more recommended formulation, where a user provides the set of operators
    they know how to implement, and we provide decompositions for everything
    not in this set.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    registry = global_decomposition_table[type]
    packets_to_overloads = defaultdict(list)
    for opo in registry:
        if isinstance(opo, (OpOverload, OpOverloadPacket)):
            packets_to_overloads[opo.overloadpacket].append(opo)
    decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
            for op_overload in packets_to_overloads[op]:
                decompositions[op_overload] = registry[op_overload]
        elif isinstance(op, torch._ops.OperatorBase) and op in registry:
            decompositions[op] = registry[op]
    return decompositions
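

# Illustrative sketch (kept as a comment so nothing runs at import): requesting
# an overload packet returns every registered overload of that packet, keyed by
# the individual OpOverloads.
#
#   decomps = get_decompositions([torch.ops.aten.addcdiv, torch.ops.aten.elu])
#   # e.g. {aten.addcdiv.default: <fn>, aten.elu.default: <fn>, ...}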


def remove_decompositions(
    decompositions: Dict[torch._ops.OperatorBase, Callable],
    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
) -> None:
    """
    Given a dictionary of decompositions obtained from get_decompositions(), removes
    operators associated with a list of operator overloads and overload packets passed
    as input. If the decomposition dictionary does not contain a decomposition that is
    specified to be removed, it is silently ignored.
    """
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket):
            for overload_name in op.overloads():
                opo = getattr(op, overload_name)
                decompositions.pop(opo, None)
        elif isinstance(op, OpOverload):
            decompositions.pop(op, None)
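

# Illustrative sketch (kept as a comment so nothing runs at import): a backend
# that implements an op natively can carve it back out of a retrieved table.
#
#   decomps = get_decompositions([torch.ops.aten.elu, torch.ops.aten.mv])
#   remove_decompositions(decomps, [torch.ops.aten.elu])
#   # every aten.elu.* overload is now absent; the aten.mv entries remain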


# populate the table
import torch._decomp.decompositions
import torch._refs


# See NOTE [Core ATen Ops]
#
# This list was copied from torch/_inductor/decomposition.py, excluding
# decompositions that result in prim ops. The resulting opset of the
# decompositions is the core ATen opset.
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
    aten = torch.ops.aten
    return get_decompositions(
        [
            aten.addcdiv,
            aten.addcdiv_,
            aten.addcmul,
            aten.addcmul_,
            aten.addr,
            aten.affine_grid_generator,
            aten.all,
            aten.aminmax,
            aten.arange.default,
            aten.arange.start,
            aten.avg_pool2d_backward,
            aten.baddbmm,
            aten.binary_cross_entropy,
            aten.binary_cross_entropy_backward,
            aten.binary_cross_entropy_with_logits,
            aten.block_diag,
            aten.celu,
            aten.celu_,
            aten.clamp_max,
            aten.clamp_min,
            aten.col2im,
            aten.count_nonzero,
            aten.linalg_cross,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.miopen_batch_norm_backward,
            aten.deg2rad,
            aten.deg2rad_,
            aten.detach,
            aten.diag_embed,
            aten.diagonal_backward,
            aten.dot,
            aten.vdot,
            aten.elu,
            aten.elu_,
            aten.elu_backward,
            aten._embedding_bag,
            aten.embedding_dense_backward,
            aten.empty_like,
            aten._euclidean_dist.default,
            aten.expand_as,
            aten.eye,
            aten.fill,
            aten.fill_,
            aten.floor_divide,
            aten.frac,
            aten.frac_,
            aten._fused_moving_avg_obs_fq_helper,
            aten.gelu_,
            aten.gelu_backward,
            aten.glu,
            aten.glu_backward,
            aten.hardshrink,
            aten.hardsigmoid,
            aten.hardsigmoid_,
            aten.hardsigmoid_backward,
            aten.hardswish,
            aten.hardswish_,
            aten.hardswish_backward,
            aten.hardtanh_,
            aten.hardtanh_backward,
            aten.heaviside,
            aten.heaviside_,
            aten.huber_loss,
            aten.huber_loss_backward,
            aten.im2col,
            aten.index_add,
            aten.index_add_,
            aten.index_copy,
            aten.index_copy_,
            aten.index_fill,
            aten.index_fill_,
            aten.isin,
            aten.isneginf,
            aten.isposinf,
            aten.l1_loss,
            aten._lazy_clone,
            aten._test_parallel_materialize,
            aten.leaky_relu_,
            aten.leaky_relu_backward,
            aten.lerp,
            aten.lerp_,
            aten.linspace,
            aten.logaddexp,
            aten.logaddexp2,
            aten.logit,
            aten.logit_,
            aten.logit_backward,
            aten.log_sigmoid_backward,
            aten.log_sigmoid_forward,
            aten._log_softmax_backward_data,
            aten.logspace,
            aten.logsumexp.default,
            aten.masked_fill,
            aten.masked_fill_,
            aten.mish,
            aten.mish_,
            aten.mse_loss,
            aten.mse_loss_backward,
            aten.multi_margin_loss,
            aten.multilabel_margin_loss_forward,
            aten.mv,
            aten.mvlgamma,
            aten.mvlgamma_,
            aten.nansum,
            aten.nan_to_num,
            aten.nan_to_num_,
            aten.narrow,
            aten.native_batch_norm_backward,
            aten.native_dropout_backward,
            aten.native_group_norm_backward,
            aten.native_layer_norm_backward,
            aten.new_empty,
            aten.new_full,
            aten.new_ones,
            aten.new_zeros,
            aten.nll_loss_backward,
            aten.nll_loss_forward,
            aten.norm,
            aten.ones,
            aten.ones_like,
            aten.pixel_shuffle,
            aten.pixel_unshuffle,
            aten._prelu_kernel,
            aten._prelu_kernel_backward,
            aten._reshape_alias,
            aten.rad2deg,
            aten.rad2deg_,
            aten.reflection_pad1d,
            aten.reflection_pad2d,
            aten.reflection_pad3d,
            aten.replication_pad1d,
            aten.replication_pad2d,
            aten.replication_pad3d,
            aten.renorm,
            aten.renorm_,
            aten.resize_as,
            aten.roll,
            aten.rot90,
            aten.rrelu_with_noise,
            aten.rrelu_with_noise_,
            aten.rsub,
            aten._scaled_dot_product_flash_attention_for_cpu.default,
            aten.select_backward,
            aten.select_scatter,
            aten.sgn,
            aten.sgn_,
            aten.sigmoid_backward,
            aten.silu,
            aten.silu_,
            aten.silu_backward,
            aten.sinc,
            aten.sinc_,
            aten.slice_backward,
            aten.smooth_l1_loss,
            aten.smooth_l1_loss_backward,
            aten.soft_margin_loss,
            aten.soft_margin_loss_backward,
            aten._softmax_backward_data,
            aten.softplus,
            aten.softplus_backward,
            aten.softshrink,
            aten.special_entr,
            aten.special_log_ndtr,
            aten.special_xlog1py,
            aten.split.Tensor,
            aten.split_with_sizes_copy,
            aten.squeeze.default,
            aten.squeeze.dim,
            aten.std,
            aten.std_mean,
            aten.stack,
            aten.sum.default,
            aten.sum.out,
            aten.t,
            aten.take,
            aten.tanh_backward,
            aten.threshold,
            aten.threshold_,
            aten.threshold_backward,
            aten.trace,
            aten.transpose.int,
            aten.tril,
            aten.tril_,
            aten.triu,
            aten.triu_,
            aten.unbind,
            aten.unfold_backward,
            aten.unfold_copy,
            aten._unsafe_index,
            aten.unsafe_split.Tensor,
            aten.unsafe_split_with_sizes,
            aten._unsafe_view,
            aten.upsample_linear1d,
            aten.upsample_bilinear2d,
            aten.upsample_trilinear3d,
            aten.upsample_nearest2d_backward,
            aten.view_as_complex,
            aten.xlogy,
            aten.xlogy_,
            aten.zero,
            aten.zero_,
            aten.zeros,
            aten.zeros_like,
            aten._chunk_cat,
            aten._weight_norm_interface,
        ]
    )
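

# Illustrative sketch (kept as a comment so nothing runs at import): one common
# consumer of this table is torch.export, which can re-trace an exported
# program so that only core ATen ops remain in the graph.
#
#   ep = torch.export.export(model, example_args)
#   ep = ep.run_decompositions(core_aten_decompositions())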