impl.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import functools
  4. import inspect
  5. import sys
  6. import typing
  7. import weakref
  8. from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
  9. import torch
  10. import torch._C as _C
  11. import torch.library as library
  12. from torch._library.abstract_impl import AbstractImplCtx
  13. from torch.library import get_ctx
  14. from .autograd import autograd_kernel_indirection, construct_autograd_kernel
  15. import torch._library.infer_schema
  16. from torch._library.infer_schema import infer_schema
  17. """
  18. For a detailed guide on custom ops, please see
  19. https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
  20. This file includes pieces of the implementation of our custom operator API.
  21. """
  22. __all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
  23. SUPPORTED_DEVICE_TYPE_TO_KEY = {
  24. "cpu": "CPU",
  25. "cuda": "CUDA",
  26. }
  27. # We will not let users register CustomOps with anything that could look like
  28. # PyTorch internals to avoid confusion.
  29. RESERVED_NS = {
  30. "prim",
  31. "prims",
  32. "aten",
  33. "at",
  34. "torch",
  35. "pytorch",
  36. }
  37. def custom_op(
  38. qualname: str, manual_schema: typing.Optional[str] = None
  39. ) -> typing.Callable:
  40. r"""Creates a new CustomOp object.
  41. WARNING: if you're a user, please do not use this directly
  42. (instead use the torch._custom_ops APIs).
  43. Also please see the following for a detailed guide on custom ops.
  44. https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
  45. In PyTorch, defining an op (short for "operator") is a two step-process:
  46. - we need to define (create) the op
  47. - we need to implement behavior for how the operator interacts with
  48. various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
  49. This entrypoint defines the CustomOp object (the first step);
  50. you must then perform the second step by calling various methods on
  51. the CustomOp object.
  52. This API is used as a decorator (see examples).
  53. Arguments:
  54. qualname (str): Should be a string that looks like
  55. "namespace::operator_name". Operators in PyTorch need a namespace to
  56. avoid name collisions; a given operator may only be created once.
  57. If you are writing a Python library, we recommend the namespace to
  58. be the name of your top-level module. The operator_name must be
  59. the same as the name of the function you pass to custom_op
  60. (see examples).
  61. manual_schema (Optional[str]): Each PyTorch operator needs a schema that
  62. tells PyTorch the types of the inputs/outputs. If None (default),
  63. we will infer the schema from the type annotations on the function
  64. (see examples). Otherwise, if you don't want to use type annotations,
  65. you may provide us the schema string.
  66. Example::
  67. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  68. >>> import numpy as np
  69. >>> from torch import Tensor
  70. >>>
  71. >>> # Step 1: define the CustomOp.
  72. >>> # We need to provide the decorator a "prototype function"
  73. >>> # (a function with Python ellipses as the body).
  74. >>> @custom_op("my_library::numpy_sin")
  75. >>> def numpy_sin(x: Tensor) -> Tensor:
  76. >>> ...
  77. >>>
  78. >>> # numpy_sin is now an instance of class CustomOp
  79. >>> print(type(numpy_sin))
  80. >>>
  81. >>> # Step 2: Register an implementation for various PyTorch subsystems
  82. >>>
  83. >>> # Register an implementation for CPU tensors
  84. >>> @numpy_sin.impl('cpu')
  85. >>> def numpy_sin_impl_cpu(x):
  86. >>> return torch.from_numpy(np.sin(x.numpy()))
  87. >>>
  88. >>> # Register an implementation for CUDA tensors
  89. >>> @numpy_sin.impl('cuda')
  90. >>> def numpy_sin_impl_cuda(x):
  91. >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
  92. >>>
  93. >>> x = torch.randn(3)
  94. >>> numpy_sin(x) # calls numpy_sin_impl_cpu
  95. >>>
  96. >>> x_cuda = x.cuda()
  97. >>> numpy_sin(x) # calls numpy_sin_impl_cuda
  98. """
  99. def inner(func):
  100. if not inspect.isfunction(func):
  101. raise ValueError(
  102. f"custom_op(...)(func): Expected `func` to be a Python "
  103. f"function, got: {type(func)}"
  104. )
  105. ns, name = parse_qualname(qualname)
  106. validate_namespace(ns)
  107. if func.__name__ != name:
  108. raise ValueError(
  109. f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
  110. f"to have name '{name}' but got '{func.__name__}'. "
  111. f"Please either change the name of `func` or the qualname that "
  112. f"is passed to `custom_op`"
  113. )
  114. schema = infer_schema(func) if manual_schema is None else manual_schema
  115. schema_str = f"{name}{schema}"
  116. function_schema = FunctionSchema.parse(schema_str)
  117. validate_schema(function_schema)
  118. if manual_schema is not None:
  119. validate_function_matches_schema(function_schema, func)
  120. lib = library.Library(ns, "FRAGMENT")
  121. lib.define(schema_str)
  122. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  123. result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
  124. result.__name__ = func.__name__
  125. result.__module__ = func.__module__
  126. result.__doc__ = func.__doc__
  127. library.impl(lib, result._opname, "Autograd")(
  128. autograd_kernel_indirection(weakref.proxy(result))
  129. )
  130. torch._C._dispatch_set_report_error_callback(
  131. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  132. )
  133. return result
  134. return inner
  135. # Global dictionary holding references to all CustomOp objects
  136. # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
  137. # Used to query the CustomOp associated with a specific C++ dispatcher operator.
  138. # An example usage is FakeTensor: FakeTensor checks if a specific operator
  139. # has an implementation registered via the CustomOp API.
  140. # Indexed by qualname (e.g. aten::foo)
  141. global_registry: typing.Dict[str, "CustomOp"] = {}
  142. class CustomOp:
  143. r"""Class for custom operators in PyTorch.
  144. Use the CustomOp API to create user-defined custom operators that behave
  145. just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
  146. comes to various PyTorch subsystems (like torch.compile).
  147. To construct a `CustomOp`, use `custom_op`.
  148. """
  149. def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
  150. super().__init__()
  151. if not _private_access:
  152. raise RuntimeError(
  153. "The CustomOp constructor is private and we do not guarantee "
  154. "BC for it. Please use custom_op(...) to create a CustomOp object"
  155. )
  156. name = f"{cpp_ns}::{operator_name}"
  157. self._schema = schema
  158. self._cpp_ns = cpp_ns
  159. self._lib: library.Library = lib
  160. self._ophandle: _C._DispatchOperatorHandle = ophandle
  161. # Has the name of the op, e.g. "foo". We cache here for convenience.
  162. self._opname: str = operator_name
  163. # this is _opname but with namespace. e.g. "custom::foo"
  164. self._qualname: str = name
  165. self.__name__ = None # mypy requires this
  166. # NB: Some of these impls are registered as kernels to DispatchKeys.
  167. # Modifying the _impls dict directly won't do anything in that case.
  168. self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
  169. # See NOTE [CustomOp autograd kernel indirection]
  170. self._registered_autograd_kernel_indirection = False
  171. global_registry[self._qualname] = self
  172. def _register_autograd_kernel_indirection(self):
  173. assert not self._registered_autograd_kernel_indirection
  174. self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
  175. self._registered_autograd_kernel_indirection = True
  176. # Records the impl and the source location in self._impls
  177. # Note that this doesn't cause torch.library to use the impl, that
  178. # needs to be done in a separate self._lib.impl call.
  179. def _register_impl(self, kind, func, stacklevel=2):
  180. if self._has_impl(kind):
  181. func_and_location = self._impls[kind]
  182. assert func_and_location is not None # Pacify mypy
  183. location = func_and_location.location
  184. raise RuntimeError(
  185. f"Attempting to register a {kind} impl for operator {self._qualname} "
  186. f"that already has a {kind} impl registered from Python at "
  187. f"{location}. This is not supported."
  188. )
  189. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  190. location = f"{frame.filename}:{frame.lineno}"
  191. self._impls[kind] = FuncAndLocation(func, location)
  192. def _get_impl(self, kind):
  193. return self._impls[kind]
  194. def _has_impl(self, kind):
  195. return kind in self._impls
  196. def _destroy(self):
  197. # NOTE: [CustomOp lifetime]
  198. # A CustomOp, once created, lives forever. The mechanism is that the
  199. # global registry holds a reference to it. However, to make testing
  200. # easier, we want to be able to destroy CustomOp objects.
  201. # CustomOp._destroy does the job, though it leaves the CustomOp
  202. # in a garbage state.
  203. del self._lib
  204. opnamespace = getattr(torch.ops, self._cpp_ns)
  205. if hasattr(opnamespace, self._opname):
  206. delattr(opnamespace, self._opname)
  207. del global_registry[self._qualname]
  208. def __repr__(self):
  209. return f'<CustomOp(op="{self._qualname}")>'
  210. def __call__(self, *args, **kwargs):
  211. # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
  212. # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
  213. # issues from caching operators that make testing CustomOp difficult).
  214. result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
  215. return result
  216. def impl(
  217. self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
  218. ) -> typing.Callable:
  219. r"""Register an implementation for a device type for this CustomOp object.
  220. WARNING: if you're a user, please do not use this directly
  221. (instead use the torch._custom_ops APIs).
  222. Also please see the following for a detailed guide on custom ops.
  223. https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
  224. If the CustomOp is passed multiple Tensor inputs with different device
  225. types, it will dispatch to the registered implementation for the highest
  226. priority device type among those present.
  227. The supported device types, in order of priority, are {'cuda', 'cpu'}.
  228. This API is used as a decorator (see examples).
  229. Arguments:
  230. device_types (str or Iterable[str]): the device type(s) to register the function for.
  231. Examples::
  232. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  233. >>> import numpy as np
  234. >>> from torch import Tensor
  235. >>>
  236. >>> @custom_op("my_library::numpy_cos")
  237. >>> def numpy_cos(x: Tensor) -> Tensor:
  238. >>> ...
  239. >>>
  240. >>> # Register an implementation for CPU Tensors
  241. >>> @numpy_cos.impl('cpu')
  242. >>> def numpy_cos_impl_cpu(x):
  243. >>> return torch.from_numpy(np.cos(x.numpy()))
  244. >>>
  245. >>> # Register an implementation for CUDA Tensors
  246. >>> @numpy_cos.impl('cuda')
  247. >>> def numpy_cos_impl_cuda(x):
  248. >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
  249. >>>
  250. >>> x = torch.randn(3)
  251. >>> numpy_cos(x) # calls numpy_cos_impl_cpu
  252. >>>
  253. >>> x_cuda = x.cuda()
  254. >>> numpy_cos(x) # calls numpy_cos_impl_cuda
  255. """
  256. if isinstance(device_types, str):
  257. device_types = [device_types]
  258. for device_type in device_types:
  259. validate_device_type(device_type)
  260. def inner(f):
  261. for device_type in set(device_types):
  262. self._check_doesnt_have_library_impl(device_type)
  263. self._register_impl(device_type, f, stacklevel=_stacklevel)
  264. dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  265. library.impl(self._lib, self._opname, dispatch_key)(f)
  266. return f
  267. return inner
  268. def _check_doesnt_have_library_impl(self, device_type):
  269. if self._has_impl(device_type):
  270. return
  271. key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  272. if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
  273. raise RuntimeError(
  274. f"impl(..., device_types={device_type}): the operator {self._qualname} "
  275. f"already has an implementation for this device type via a "
  276. f"pre-existing torch.library or TORCH_LIBRARY registration.")
  277. def impl_factory(self) -> typing.Callable:
  278. r"""Register an implementation for a factory function."""
  279. def inner(f):
  280. self._register_impl("factory", f)
  281. library.impl(self._lib, self._opname, "BackendSelect")(f)
  282. return f
  283. return inner
  284. def impl_abstract(self, _stacklevel=2) -> typing.Callable:
  285. r"""Register an abstract implementation for this operator.
  286. WARNING: please do not use this directly (and instead use the torch._custom_ops
  287. APIs). Also please see the following for a detailed guide on custom ops.
  288. https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
  289. An "abstract implementation" specifies the behavior of this operator on
  290. Tensors that carry no data. Given some input Tensors with certain properties
  291. (sizes/strides/storage_offset/device), it specifies what the properties of
  292. the output Tensors are.
  293. The abstract implementation has the same signature as the operator.
  294. It is run for both FakeTensors and meta tensors. To write an abstract
  295. implementation, assume that all Tensor inputs to the operator are
  296. regular CPU/CUDA/Meta tensors, but they do not have storage, and
  297. you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
  298. The abstract implementation must consist of only PyTorch operations
  299. (and may not directly access the storage or data of any input or
  300. intermediate Tensors).
  301. This API is used as a decorator (see examples).
  302. Examples::
  303. >>> import numpy as np
  304. >>> from torch import Tensor
  305. >>>
  306. >>> # Example 1: an operator without data-dependent output shape
  307. >>> @custom_op('my_library::custom_linear')
  308. >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
  309. >>> ...
  310. >>>
  311. >>> @custom_linear.impl_abstract()
  312. >>> def custom_linear_abstract(x, weight):
  313. >>> assert x.dim() == 2
  314. >>> assert weight.dim() == 2
  315. >>> assert bias.dim() == 1
  316. >>> assert x.shape[1] == weight.shape[1]
  317. >>> assert weight.shape[0] == bias.shape[0]
  318. >>> assert x.device == weight.device
  319. >>>
  320. >>> return (x @ weight.t()) + bias
  321. >>>
  322. >>> # Example 2: an operator with data-dependent output shape
  323. >>> @custom_op('my_library::custom_nonzero')
  324. >>> def custom_nonzero(x: Tensor) -> Tensor:
  325. >>> ...
  326. >>>
  327. >>> @custom_nonzero.impl_abstract()
  328. >>> def custom_nonzero_abstract(x):
  329. >>> # Number of nonzero-elements is data-dependent.
  330. >>> # Since we cannot peek at the data in an abstract impl,
  331. >>> # we use the ctx object to construct a new symint that
  332. >>> # represents the data-dependent size.
  333. >>> ctx = torch._custom_op.get_ctx()
  334. >>> nnz = ctx.create_unbacked_symint()
  335. >>> shape = [x.dim(), nnz]
  336. >>> result = x.new_empty(shape, dtype=torch.long)
  337. >>> return result
  338. >>>
  339. >>> @custom_nonzero.impl(['cpu', 'cuda'])
  340. >>> def custom_nonzero_impl(x):
  341. >>> x_np = to_numpy(x)
  342. >>> res = np.stack(np.nonzero(x_np), axis=1)
  343. >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
  344. >>> # constrain the range to at least 2
  345. >>> if res.shape[0] <= 1:
  346. >>> raise RuntimeError("not supported")
  347. >>> return torch.tensor(res, device=x.device)
  348. """
  349. def inner(f):
  350. self._check_doesnt_have_library_meta_impl()
  351. self._register_impl("abstract", f, stacklevel=_stacklevel)
  352. location = self._get_impl("abstract").location
  353. qualname = self._qualname
  354. # Handle DispatchKey.Meta registration
  355. @functools.wraps(f)
  356. def f_with_ctx(*args, **kwargs):
  357. def error_on_ctx():
  358. raise RuntimeError(
  359. f"Attempted to call get_ctx() for the meta implementation "
  360. f"for {qualname}."
  361. f"You have presumably called get_ctx() because the operator "
  362. f"has a data-dependent output shape; if so, there is no "
  363. f"such meta implementation and this error is the correct "
  364. f"behavior. Otherwise, please remove the call to get_ctx() "
  365. f"in the implementation registered with impl_abstract "
  366. f"at {location}"
  367. )
  368. with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
  369. return f(*args, **kwargs)
  370. self._lib.impl(self._opname, f_with_ctx, "Meta")
  371. return f
  372. return inner
  373. def _check_can_register_backward(self):
  374. def error(detail):
  375. raise RuntimeError(
  376. f"Cannot use torch._custom_ops APIs to register backward "
  377. f"formula for {detail}. Got operator "
  378. f"{self._qualname} with schema: {schema}"
  379. )
  380. schema = self._schema
  381. if schema.kind() != SchemaKind.functional:
  382. error("non-functional operator")
  383. rets = schema.returns
  384. if not schema.returns:
  385. error("operator with no returns")
  386. assert len(rets) > 0
  387. is_non_mutating_view = any(
  388. r.annotation is not None and not r.annotation.is_write for r in rets
  389. )
  390. if is_non_mutating_view:
  391. error("operator that returns views")
  392. # We make assumptions about the schema's return types.
  393. allowed_return_types = {
  394. BaseType(BaseTy.int): "int",
  395. BaseType(BaseTy.SymInt): "SymInt",
  396. BaseType(BaseTy.bool): "bool",
  397. BaseType(BaseTy.float): "float",
  398. BaseType(BaseTy.Tensor): "Tensor",
  399. ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
  400. }
  401. for ret in schema.returns:
  402. if ret.type in allowed_return_types:
  403. continue
  404. error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
  405. def _check_doesnt_have_library_autograd_impl(self):
  406. if self._registered_autograd_kernel_indirection:
  407. return
  408. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
  409. raise RuntimeError(
  410. f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
  411. f"already has an implementation for this device type via a "
  412. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  413. f"CompositeImplicitAutograd operators do not need an autograd formula; "
  414. f"instead, the operator will decompose into its constituents and those "
  415. f"can have autograd formulas defined on them.")
  416. # We can improve this by adding "all Autograd<BACKEND> keys", but
  417. # realistically people will just be using this API for CPU/CUDA for now.
  418. for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
  419. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
  420. raise RuntimeError(
  421. f"impl_backward/impl_save_for_backward: "
  422. f"the operator {self._qualname} already has an Autograd kernel "
  423. f"registered to DispatchKey::{key} vi a pre-existing "
  424. f"torch.library or TORCH_LIBRARY registration. Please either "
  425. f"remove those registrations or don't use the torch._custom_ops APIs")
  426. def _check_doesnt_have_library_meta_impl(self):
  427. if self._has_impl("abstract"):
  428. return
  429. # If the user's operator is CompositeExplicitAutograd,
  430. # allow them to impl_abstract. This is being pragmatic
  431. # (existing custom ops may have CompositeExplicitAutograd
  432. # registration that don't work with Meta kernels, so this
  433. # gives them an escape hatch).
  434. if (
  435. _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
  436. and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
  437. ):
  438. return
  439. # Otherwise, if the user's already has a Meta kernel or their
  440. # op is CompositeImplicitAutograd or some other alias dispatch key,
  441. # raise.
  442. # Special case for CompositeImplicitAutograd
  443. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
  444. raise RuntimeError(
  445. f"impl_abstract(...): the operator {self._qualname} "
  446. f"already has an implementation for this device type via a "
  447. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  448. f"CompositeImplicitAutograd operators do not need an abstract impl; "
  449. f"instead, the operator will decompose into its constituents and those "
  450. f"can have abstract impls defined on them.")
  451. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
  452. raise RuntimeError(
  453. f"impl_abstract(...): the operator {self._qualname} "
  454. f"already has an DispatchKey::Meta implementation via a "
  455. f"pre-existing torch.library or TORCH_LIBRARY registration. "
  456. f"Please either remove that registration or don't call impl_abstract.")
  457. # NOTE ["backward", "save_for_backward", and "autograd"]
  458. # As a part of the explicit autograd API, a user must provide us
  459. # a "save_for_backward" function and a "backward" function.
  460. # When both of these have been provided, then we automatically
  461. # construct the "autograd" kernel.
  462. def _register_autograd_kernel(self):
  463. assert self._has_impl("backward")
  464. assert self._has_impl("save_for_backward")
  465. kernel = construct_autograd_kernel(
  466. self._schema,
  467. self._output_differentiability,
  468. self,
  469. get_op(self._qualname),
  470. self._get_impl("save_for_backward").func,
  471. self._get_impl("backward").func)
  472. self._register_impl("autograd", kernel)
  473. def impl_save_for_backward(self, _stacklevel=2):
  474. r"""Register a function that tells us what to save for backward.
  475. Please see impl_backward for more details.
  476. """
  477. def inner(f):
  478. self._check_can_register_backward()
  479. self._check_doesnt_have_library_autograd_impl()
  480. if not self._registered_autograd_kernel_indirection:
  481. self._register_autograd_kernel_indirection()
  482. self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
  483. if self._has_impl("backward"):
  484. self._register_autograd_kernel()
  485. return inner
  486. def impl_backward(self, output_differentiability=None, _stacklevel=2):
  487. r"""Registers a backward formula.
  488. WARNING: if you're a user, please do not use this directly
  489. (instead use the torch._custom_ops APIs).
  490. Also please see the following for a detailed guide on custom ops.
  491. https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
  492. In order for the CustomOp to work with autograd, you need to register
  493. a backward formula. There are two pieces to this:
  494. 1. You must give us a function to specify what to save for backward.
  495. Call this the "save for backward" function.
  496. 2. You must give us a function that computes gradients. Call this the
  497. "backward" function.
  498. Use `impl_save_for_backward` to define a "save for backward" function
  499. that specifies what gets saved for backward. The function should accept
  500. two arguments ``(inputs, output)`` and return the quantities to be saved
  501. for backward.
  502. During runtime, when you call the CustomOp, PyTorch will invoke the
  503. "save for backward" function with the inputs and output of the CustomOp.
  504. Use `impl_backward` to define the "backward" function. The backward
  505. function must accept ``(ctx, saved, *grads)``:
  506. - ``ctx`` is a context object where we may provide information
  507. - ``saved`` is exactly what gets returned from the "save for backward"
  508. function
  509. - ``grads`` is one or more gradients. The number of gradients matches
  510. the number of outputs of the CustomOp.
  511. The backward function must return a dict that maps the name of
  512. an input to the CustomOp to its corresponding gradient. All inputs that
  513. were declared to be Tensors in the CustomOp definition must be accounted
  514. for in the dict. The gradient may be a Tensor or None.
  515. """
  516. if output_differentiability is not None:
  517. def yell():
  518. raise RuntimeError(
  519. f"impl_backward(output_differentiability): expected "
  520. f"output_differentiability to be a list of bools with "
  521. f"length equal to the number of outputs of this CustomOp "
  522. f"got: {output_differentiability}")
  523. if not isinstance(output_differentiability, list):
  524. yell()
  525. for diff in output_differentiability:
  526. if not isinstance(diff, bool):
  527. yell()
  528. if len(self._schema.returns) != len(output_differentiability):
  529. yell()
  530. def inner(f):
  531. self._check_can_register_backward()
  532. self._check_doesnt_have_library_autograd_impl()
  533. if not self._registered_autograd_kernel_indirection:
  534. self._register_autograd_kernel_indirection()
  535. self._register_impl("backward", f, stacklevel=_stacklevel)
  536. self._output_differentiability = output_differentiability
  537. if self._has_impl("save_for_backward"):
  538. self._register_autograd_kernel()
  539. return inner
  540. @dataclasses.dataclass
  541. class FuncAndLocation:
  542. func: typing.Callable
  543. location: str
  544. def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
  545. overload_name = (
  546. "" if operator_name.overload_name is None else operator_name.overload_name
  547. )
  548. return _C._dispatch_find_schema_or_throw(
  549. f"{cpp_ns}::{str(operator_name.name)}", overload_name
  550. )
  551. def validate_namespace(ns: str) -> None:
  552. if "." in ns:
  553. raise ValueError(
  554. f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
  555. f"valid variable name)"
  556. )
  557. if ns in RESERVED_NS:
  558. raise ValueError(
  559. f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
  560. f"please choose something else. "
  561. )
  562. def validate_schema(schema: FunctionSchema) -> None:
  563. if not torch._library.utils.is_functional_schema(schema):
  564. raise ValueError(
  565. f"custom_op only supports functional operators "
  566. f"(ops that do not mutate any inputs, do not return "
  567. f"views of the inputs, and has at least one return). "
  568. f"Got the following non-functional schema: {schema}"
  569. )
  570. # For simplicity: don't allow self arguments
  571. if schema.arguments.self_arg is not None:
  572. raise ValueError(
  573. f"custom_op does not support arguments named 'self'. Please "
  574. f"rename your argument. Got: {schema}"
  575. )
  576. def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
  577. names = qualname.split("::", 1)
  578. if len(names) != 2:
  579. raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
  580. f"operator name should look something like ns::foo")
  581. if '.' in names[1]:
  582. raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
  583. f"i.e. operator names with '.' in them. "
  584. f"Please name your operator something like ns::foo. "
  585. f"Got: {qualname}")
  586. return names[0], names[1]
  587. def validate_device_type(device_type: str) -> None:
  588. if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
  589. raise ValueError(
  590. f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
  591. f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
  592. )
  593. def supported_param(param: inspect.Parameter) -> bool:
  594. return param.kind in (
  595. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  596. inspect.Parameter.KEYWORD_ONLY,
  597. )
  598. def validate_function_matches_schema(
  599. schema: FunctionSchema, func: typing.Callable
  600. ) -> None:
  601. sig = inspect.signature(func)
  602. if not all(supported_param(p) for _, p in sig.parameters.items()):
  603. raise ValueError(
  604. f"custom_op(..., manual_schema)(func): positional-only args, "
  605. f"varargs, and kwargs are not supported. Please rewrite `func` "
  606. f"to not have them. Got `func` with signature: {sig}"
  607. )
  608. if (
  609. any(
  610. p.annotation is not inspect.Parameter.empty
  611. for _, p in sig.parameters.items()
  612. )
  613. or sig.return_annotation is not inspect.Signature.empty
  614. ):
  615. raise ValueError(
  616. f"custom_op(..., manual_schema)(func): When passing in a manual "
  617. f"schema, we expect `func` to have no type annotations to avoid "
  618. f"ambiguity. Got `func` with signature: {sig}"
  619. )
  620. positional = [
  621. (name, param)
  622. for name, param in sig.parameters.items()
  623. if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  624. ]
  625. kwargonly = [
  626. (name, param)
  627. for name, param in sig.parameters.items()
  628. if param.kind == inspect.Parameter.KEYWORD_ONLY
  629. ]
  630. def error():
  631. raise ValueError(
  632. f"custom_op(..., manual_schema)(func): When passing in a manual "
  633. f"schema, we expect `func`'s signature to match `manual_schema` "
  634. f"(aside from type annotations). "
  635. f"func's signature: {sig}, manual_schema: {schema}"
  636. )
  637. def error_default_args():
  638. raise ValueError(
  639. f"custom_op(..., manual_schema)(func): "
  640. f"neither func nor manual_schema should have default "
  641. f"arguments. Got "
  642. f"func's signature: {sig}, manual_schema: {schema}"
  643. )
  644. def compare(sig_args, schema_args):
  645. if len(sig_args) != len(schema_args):
  646. error()
  647. for (name, param), arg in zip(sig_args, schema_args):
  648. if name != arg.name:
  649. error()
  650. if param.default is not inspect.Parameter.empty or arg.default is not None:
  651. error_default_args()
  652. compare(positional, schema.arguments.flat_positional)
  653. compare(kwargonly, schema.arguments.flat_kwarg_only)
  654. def report_error_callback(custom_op: typing.Any, key: str) -> None:
  655. if key == "Undefined":
  656. raise NotImplementedError(
  657. f"{custom_op}: There were no Tensor inputs to this operator "
  658. f"(e.g. you passed an empty list of Tensors). If your operator is a "
  659. f"factory function (that is, it takes no Tensors and constructs "
  660. f"a new one), then please use CustomOp.impl_factory to register "
  661. f"an implementation for it"
  662. )
  663. if key == "Meta":
  664. raise NotImplementedError(
  665. f"{custom_op}: when running with device='Meta' tensors: there is no "
  666. f"abstract impl registered for this CustomOp. Please register one via "
  667. f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
  668. )
  669. if key in ("CPU", "CUDA"):
  670. device = key.lower()
  671. raise NotImplementedError(
  672. f"{custom_op}: when running with device='{device}' tensors: there is no "
  673. f"{device} impl registered for this CustomOp. Please register one via "
  674. f"CustomOp.impl(device_type='{device}')"
  675. )
  676. raise NotImplementedError(
  677. f"{custom_op}: No implementation for dispatch key {key}. It is likely "
  678. f"that we have not added this functionality yet, please either open an "
  679. f"issue or if you're feeling adventurous, use the low-level "
  680. f"torch.library API"
  681. )
  682. def custom_op_from_existing(op):
  683. ns = op.namespace
  684. lib = torch.library.Library(ns, "FRAGMENT")
  685. name = op.name().split("::")[-1]
  686. schema_str = str(op._schema)
  687. # CustomOp expects the schema string without the namespace
  688. schema_str = schema_str.split("::")[-1]
  689. schema = FunctionSchema.parse(schema_str)
  690. return CustomOp(lib, ns, schema, name, op, _private_access=True)
  691. def get_op(qualname):
  692. def error_not_found():
  693. raise ValueError(
  694. f"Could not find the operator {qualname}. Please make sure you have "
  695. f"already registered the operator and (if registered from C++) "
  696. f"loaded it via torch.ops.load_library.")
  697. ns, name = parse_qualname(qualname)
  698. if not hasattr(torch.ops, ns):
  699. error_not_found()
  700. opnamespace = getattr(torch.ops, ns)
  701. if not hasattr(opnamespace, name):
  702. error_not_found()
  703. packet = getattr(opnamespace, name)
  704. if not hasattr(packet, 'default'):
  705. error_not_found()
  706. return packet.default
  707. def _find_custom_op(qualname, also_check_torch_library=False):
  708. if qualname in global_registry:
  709. return global_registry[qualname]
  710. if not also_check_torch_library:
  711. raise RuntimeError(
  712. f'Could not find custom op "{qualname}". Did you register it via '
  713. f"the torch._custom_ops API?")
  714. overload = get_op(qualname)
  715. result = custom_op_from_existing(overload)
  716. return result
  717. def get_abstract_impl(qualname):
  718. if qualname not in torch._custom_op.impl.global_registry:
  719. return None
  720. custom_op = torch._custom_op.impl.global_registry[qualname]
  721. if custom_op is None:
  722. return None
  723. if not custom_op._has_impl("abstract"):
  724. return None
  725. return custom_op._get_impl("abstract").func
  726. def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
  727. ns, name = qualname.split("::")
  728. schema_str = f"{name}{schema}"
  729. function_schema = FunctionSchema.parse(schema_str)
  730. validate_schema(function_schema)
  731. tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
  732. lib = library.Library(ns, "FRAGMENT")
  733. lib.define(schema_str, tags=tags)
  734. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  735. result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
  736. result._register_autograd_kernel_indirection()
  737. torch._C._dispatch_set_report_error_callback(
  738. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  739. )
  740. return get_op(qualname)