- # mypy: allow-untyped-defs
- import inspect
- from torch._custom_op.impl import (
- _custom_op_with_schema,
- _find_custom_op,
- infer_schema,
- parse_qualname,
- validate_namespace,
- )
- from torch.library import get_ctx
- __all__ = [
- "custom_op",
- "impl",
- "impl_abstract",
- "get_ctx",
- "impl_save_for_backward",
- "impl_backward",
- ]
- def custom_op(qualname, func_or_schema=None):
- r"""Register a new custom operator
- In PyTorch, defining an op (short for "operator") is a two step-process:
- - we need to define the op (by providing an operator name and schema)
- - we need to implement behavior for how the operator interacts with
- various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
- This entrypoint defines the custom operator (the first step);
- you must then perform the second step by calling the various
- ``impl_*`` APIs.
- This API may be used as a decorator (see examples).
- For a detailed guide on custom ops, please see
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
- Arguments:
- qualname (str): Should be a string that looks like
- "namespace::operator_name". Operators in PyTorch need a namespace to
- avoid name collisions; a given operator may only be created once.
- If you are writing a Python library, we recommend that the namespace
- be the name of your top-level module.
- func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
- schema that tells PyTorch the types of the inputs/outputs.
- If this is a Callable, we will automatically infer the schema from
- the type annotations on the function (see examples). Otherwise,
- if you don't want to use type annotations, you may provide the
- schema string directly (see the last example below).
- Example::
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
- >>> import torch
- >>> import numpy as np
- >>> from torch import Tensor
- >>>
- >>> # Step 1: define the custom op.
- >>> # We need to provide the API a "prototype function"
- >>> # (a function that raises NotImplementedError), from which
- >>> # we will infer the types of the inputs and outputs.
- >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
- >>> def numpy_sin(x: Tensor) -> Tensor:
- >>> raise NotImplementedError
- >>>
- >>> # The custom op is now accessible via the torch.ops module:
- >>> torch.ops.mylibrary.numpy_sin
- >>>
- >>> # Step 2: Register an implementation for various PyTorch subsystems
- >>>
- >>> # Register an implementation for CPU tensors
- >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
- >>> def numpy_sin_impl_cpu(x):
- >>> return torch.from_numpy(np.sin(x.numpy()))
- >>>
- >>> # Register an implementation for CUDA tensors
- >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
- >>> def numpy_sin_impl_cuda(x):
- >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
- >>>
- >>> x = torch.randn(3)
- >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
- >>>
- >>> x_cuda = x.cuda()
- >>> torch.ops.mylibrary.numpy_sin(x_cuda) # calls numpy_sin_impl_cuda
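- >>>
- >>> # Alternatively, a schema string may be passed instead of a prototype
- >>> # function if you do not want to use type annotations. This is a
- >>> # minimal sketch; the op name and schema below are illustrative.
- >>> torch._custom_ops.custom_op("mylibrary::numpy_tanh", "(Tensor x) -> Tensor")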
- """
- ns, name = parse_qualname(qualname)
- validate_namespace(ns)
- def inner(func):
- if not inspect.isfunction(func):
- raise ValueError(
- f"custom_op(...)(func): Expected `func` to be a Python "
- f"function, got: {type(func)}"
- )
- if func.__name__ != name:
- raise ValueError(
- f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
- f"to have name '{name}' but got '{func.__name__}'. "
- f"Please either change the name of `func` or the qualname that "
- f"is passed to `custom_op`"
- )
- schema = infer_schema(func)
- _custom_op_with_schema(qualname, schema)
- return func
- if func_or_schema is None:
- return inner
- if isinstance(func_or_schema, str):
- _custom_op_with_schema(qualname, func_or_schema)
- else:
- return inner(func_or_schema)
- def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
- r"""Register an implementation for a device type for this custom op.
- If the op is passed multiple Tensor inputs with different device
- types, it will dispatch to the registered implementation for the highest
- priority device type among those present.
- The supported device types, in order of priority, are {'cuda', 'cpu'}.
- This API may be used as a decorator (see examples).
- For a detailed guide on custom ops, please see
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
- Arguments:
- qualname (str): Should be a string that looks like "namespace::operator_name".
- device_types (str or Iterable[str]): the device type(s) to register the function for.
- Example::
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
- >>> import torch
- >>> import numpy as np
- >>> from torch import Tensor
- >>>
- >>> # Step 1: define the custom op.
- >>> # We need to provide the API a "prototype function"
- >>> # (a function that raises NotImplementedError), from which
- >>> # we will infer the types of the inputs and outputs.
- >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
- >>> def numpy_cos(x: Tensor) -> Tensor:
- >>> raise NotImplementedError
- >>>
- >>> # The custom op is now accessible via the torch.ops module:
- >>> torch.ops.mylibrary.numpy_cos
- >>>
- >>> # Step 2: Register an implementation for various PyTorch subsystems
- >>>
- >>> # Register an implementation for CPU tensors
- >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
- >>> def numpy_cos_impl_cpu(x):
- >>> return torch.from_numpy(np.cos(x.numpy()))
- >>>
- >>> # Register an implementation for CUDA tensors
- >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
- >>> def numpy_cos_impl_cuda(x):
- >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
- >>>
- >>> x = torch.randn(3)
- >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
- >>>
- >>> x_cuda = x.cuda()
- >>> torch.ops.mylibrary.numpy_cos(x_cuda) # calls numpy_cos_impl_cuda
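- >>>
- >>> # A single implementation may also be registered for both CPU and CUDA
- >>> # at once via the default ``device_types=("cpu", "cuda")``. A minimal
- >>> # sketch; the ``numpy_tan`` op below is illustrative.
- >>> @torch._custom_ops.custom_op("mylibrary::numpy_tan")
- >>> def numpy_tan(x: Tensor) -> Tensor:
- >>>     raise NotImplementedError
- >>>
- >>> @torch._custom_ops.impl("mylibrary::numpy_tan")
- >>> def numpy_tan_impl(x):
- >>>     return torch.from_numpy(np.tan(x.cpu().numpy())).to(x.device)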
- """
- def inner(func):
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
- custom_op.impl(device_types, _stacklevel=3)(func)
- return func
- if func is None:
- return inner
- return inner(func)
- def impl_abstract(qualname, *, func=None):
- r"""Register an abstract implementation for this operator.
- An "abstract implementation" specifies the behavior of this operator on
- Tensors that carry no data. Given some input Tensors with certain properties
- (sizes/strides/storage_offset/device), it specifies what the properties of
- the output Tensors are.
- The abstract implementation has the same signature as the operator.
- It is run for both FakeTensors and meta tensors. To write an abstract
- implementation, assume that all Tensor inputs to the operator are
- regular CPU/CUDA/Meta tensors, but they do not have storage, and
- you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
- The abstract implementation must consist of only PyTorch operations
- (and may not directly access the storage or data of any input or
- intermediate Tensors).
- This API may be used as a decorator (see examples).
- For a detailed guide on custom ops, please see
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
- Examples::
- >>> import numpy as np
- >>> from torch import Tensor
- >>>
- >>> # Example 1: an operator without data-dependent output shape
- >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
- >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
- >>> raise NotImplementedError
- >>>
- >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
- >>> def custom_linear_abstract(x, weight, bias):
- >>> assert x.dim() == 2
- >>> assert weight.dim() == 2
- >>> assert bias.dim() == 1
- >>> assert x.shape[1] == weight.shape[1]
- >>> assert weight.shape[0] == bias.shape[0]
- >>> assert x.device == weight.device
- >>>
- >>> return (x @ weight.t()) + bias
- >>>
- >>> # Example 2: an operator with data-dependent output shape
- >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
- >>> def custom_nonzero(x: Tensor) -> Tensor:
- >>> ...
- >>>
- >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
- >>> def custom_nonzero_abstract(x):
- >>> # Number of nonzero-elements is data-dependent.
- >>> # Since we cannot peek at the data in an abstract impl,
- >>> # we use the ctx object to construct a new symint that
- >>> # represents the data-dependent size.
- >>> ctx = torch._custom_ops.get_ctx()
- >>> nnz = ctx.create_unbacked_symint()
- >>> shape = [x.dim(), nnz]
- >>> result = x.new_empty(shape, dtype=torch.long)
- >>> return result
- >>>
- >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
- >>> def custom_nonzero_impl(x):
- >>> x_np = x.cpu().numpy()
- >>> res = np.stack(np.nonzero(x_np), axis=1)
- >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
- >>> # constrain the range to at least 2
- >>> if res.shape[0] <= 1:
- >>> raise RuntimeError("not supported")
- >>> return torch.tensor(res, device=x.device)
- """
- import torch.library
- return torch.library.register_fake(qualname, func, _stacklevel=2)
- def impl_save_for_backward(qualname, *, func=None):
- r"""Register a function that tells us what to save for backward.
- Please see :func:`impl_backward` for more details.
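- Example::
- >>> # A minimal sketch, assuming the ``mylibrary::numpy_sin`` op from the
- >>> # :func:`custom_op` example has already been defined and implemented.
- >>> # The function receives the operator's inputs and output; ``inputs``
- >>> # exposes the operator's arguments by name. It returns whatever should
- >>> # be saved for the backward pass.
- >>> @torch._custom_ops.impl_save_for_backward("mylibrary::numpy_sin")
- >>> def numpy_sin_save_for_backward(inputs, output):
- >>>     return inputs.x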
- """
- def inner(func):
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
- custom_op.impl_save_for_backward(_stacklevel=3)(func)
- return func
- if func is None:
- return inner
- return inner(func)
- def impl_backward(qualname, output_differentiability=None, *, func=None):
- r"""Registers a backward formula for an operator.
- In order for an operator to work with autograd, you need to register
- a backward formula. There are two pieces to this:
- 1. You must give us a function to specify what to save for backward.
- Call this the "save for backward" function.
- 2. You must give us a function that computes gradients. Call this the
- "backward" function.
- Use `impl_save_for_backward` to define a "save for backward" function
- that specifies what gets saved for backward. The function should accept
- two arguments ``(inputs, output)`` and return the quantities to be saved
- for backward.
- During runtime, when you call the operator in a forward pass, PyTorch
- will invoke the "save for backward" function with the inputs and output
- of the operator.
- Use `impl_backward` to define the "backward" function. The backward
- function must accept ``(ctx, saved, *grads)``:
- - ``ctx`` is a context object where we may provide information
- - ``saved`` is exactly what gets returned from the "save for backward"
- function
- - ``grads`` is one or more gradients. The number of gradients matches
- the number of outputs of the operator.
- The backward function must return a dict mapping the name of each
- operator input to its corresponding gradient. Every input that
- was declared to be a Tensor in the operator definition must be accounted
- for in the dict. The gradient may be a Tensor or None.
- For a detailed guide on custom ops, please see
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
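- Example::
- >>> # A minimal sketch, assuming the ``mylibrary::numpy_sin`` op from the
- >>> # :func:`custom_op` example and a "save for backward" function that
- >>> # returned ``inputs.x`` (see :func:`impl_save_for_backward`). ``saved``
- >>> # is exactly what that function returned; the returned dict maps each
- >>> # Tensor input's name to its gradient (or None).
- >>> @torch._custom_ops.impl_backward("mylibrary::numpy_sin")
- >>> def numpy_sin_backward(ctx, saved, grad_out):
- >>>     return {"x": grad_out * saved.cos()}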
- """
- def inner(func):
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
- custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
- return func
- if func is None:
- return inner
- return inner(func)
- def _destroy(qualname):
- """De-registers a custom op. For testing purposes only"""
- custom_op = _find_custom_op(qualname)
- custom_op._destroy()
|