_custom_ops.py

# mypy: allow-untyped-defs
import inspect

from torch._custom_op.impl import (
    _custom_op_with_schema,
    _find_custom_op,
    infer_schema,
    parse_qualname,
    validate_namespace,
)
from torch.library import get_ctx


__all__ = [
    "custom_op",
    "impl",
    "impl_abstract",
    "get_ctx",
    "impl_save_for_backward",
    "impl_backward",
]


def custom_op(qualname, func_or_schema=None):
    r"""Register a new custom operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:

    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs.

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace be
            the name of your top-level module.
        func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
            schema that tells PyTorch the types of the inputs/outputs.
            If this is a Callable, we will automatically infer the schema from
            the type annotations on the function (see examples). Otherwise,
            if you don't want to use type annotations, you may provide us the
            schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_sin
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
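
    If you don't want to use type annotations, you can pass the schema string
    directly instead of a function. A minimal sketch (the op name
    ``mylibrary::numpy_tanh`` is hypothetical)::
        >>> torch._custom_ops.custom_op(
        >>>     "mylibrary::numpy_tanh", "(Tensor x) -> Tensor"
        >>> )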
  77. """
    ns, name = parse_qualname(qualname)
    validate_namespace(ns)

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )
        schema = infer_schema(func)
        _custom_op_with_schema(qualname, schema)
        return func

    if func_or_schema is None:
        return inner
    if isinstance(func_or_schema, str):
        _custom_op_with_schema(qualname, func_or_schema)
    else:
        return inner(func_or_schema)


def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
    r"""Register an implementation for a device type for this custom op.

    If the op is passed multiple Tensor inputs with different device
    types, it will dispatch to the registered implementation for the
    highest-priority device type among those present.
    The supported device types, in order of priority, are {'cuda', 'cpu'}.

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name", identifying an op created via
            :func:`custom_op`.
        device_types (str or Iterable[str]): the device type(s) to register the function for.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
        >>> def numpy_cos(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_cos
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
        >>> def numpy_cos_impl_cpu(x):
        >>>     return torch.from_numpy(np.cos(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
        >>> def numpy_cos_impl_cuda(x):
        >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda
  147. """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl(device_types, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_abstract(qualname, *, func=None):
    r"""Register an abstract implementation for this operator.

    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.

    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Examples::
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_nonzero")
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # The number of nonzero elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_ops.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
        >>> def custom_nonzero_impl(x):
        >>>     x_np = x.cpu().numpy()
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)
  218. """
  219. import torch.library
  220. return torch.library.register_fake(qualname, func, _stacklevel=2)


def impl_save_for_backward(qualname, *, func=None):
    r"""Register a function that tells us what to save for backward.

    Please see :func:`impl_backward` for more details.
    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_save_for_backward(_stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_backward(qualname, output_differentiability=None, *, func=None):
    r"""Register a backward formula for an operator.

    In order for an operator to work with autograd, you need to register
    a backward formula. There are two pieces to this:

    1. You must give us a function to specify what to save for backward.
       Call this the "save for backward" function.
    2. You must give us a function that computes gradients. Call this the
       "backward" function.

    Use :func:`impl_save_for_backward` to define a "save for backward" function
    that specifies what gets saved for backward. The function should accept
    two arguments ``(inputs, output)`` and return the quantities to be saved
    for backward.

    During runtime, when you call the operator in a forward pass, PyTorch
    will invoke the "save for backward" function with the inputs and output
    of the operator.

    Use :func:`impl_backward` to define the "backward" function. The backward
    function must accept ``(ctx, saved, *grads)``:

    - ``ctx`` is a context object where we may provide information
    - ``saved`` is exactly what gets returned from the "save for backward"
      function
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the operator.

    The backward function must return a dict that maps the name of
    an input to the operator to its corresponding gradient. All inputs that
    were declared to be Tensors in the operator definition must be accounted
    for in the dict. The gradient may be a Tensor or None.

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
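
    Example (a minimal sketch: the op ``mylibrary::numpy_mul`` is hypothetical,
    and we assume ``inputs`` exposes the operator's arguments as attributes,
    e.g. ``inputs.x``)::
        >>> import torch
        >>> from torch import Tensor
        >>>
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_mul")
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl("mylibrary::numpy_mul")
        >>> def numpy_mul_impl(x, y):
        >>>     return torch.from_numpy(x.cpu().numpy() * y.cpu().numpy()).to(x.device)
        >>>
        >>> # The "save for backward" function receives (inputs, output) and
        >>> # returns the quantities the backward function will need.
        >>> @torch._custom_ops.impl_save_for_backward("mylibrary::numpy_mul")
        >>> def numpy_mul_save_for_backward(inputs, output):
        >>>     return (inputs.x, inputs.y)
        >>>
        >>> # The backward function receives (ctx, saved, *grads) and returns
        >>> # a dict mapping each Tensor input's name to its gradient.
        >>> @torch._custom_ops.impl_backward("mylibrary::numpy_mul")
        >>> def numpy_mul_backward(ctx, saved, grad):
        >>>     x, y = saved
        >>>     return {"x": grad * y, "y": grad * x}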
  260. """
  261. def inner(func):
  262. custom_op = _find_custom_op(qualname, also_check_torch_library=True)
  263. custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
  264. return func
  265. if func is None:
  266. return inner
  267. return inner(func)


def _destroy(qualname):
    """De-registers a custom op. For testing purposes only.
    custom_op = _find_custom_op(qualname)
    custom_op._destroy()