_decomposition_utils.py 402 B

123456789101112
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch._ops import OpOverload, OpOverloadPacket
  4. def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
  5. assert not isinstance(
  6. op, OpOverloadPacket
  7. ), f"Must pass specific op overload, not overload packet, found {op}"
  8. assert isinstance(op, OpOverload)
  9. torch._C._jit_register_decomposition_for_schema(op._schema, graph)