abstract_impl.py

# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Callable, Optional

from typing_extensions import deprecated

import torch
from torch._library.utils import Kernel, RegistrationHandle


class AbstractImplHolder:
    """A holder where one can register a fake impl to."""

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        self.kernel: Optional[Kernel] = None
        self.lib: Optional[torch.library.Library] = None

    def register(self, func: Callable, source: str) -> RegistrationHandle:
        """Register a fake impl.

        Returns a RegistrationHandle that one can use to de-register this
        fake impl.
        """
        if self.kernel is not None:
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has a fake impl registered at "
                f"{self.kernel.source}."
            )
        if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call "
                f"register_fake."
            )
        if torch._C._dispatch_has_kernel_for_dispatch_key(
            self.qualname, "CompositeImplicitAutograd"
        ):
            raise RuntimeError(
                f"register_fake(...): the operator {self.qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need a fake "
                f"impl; instead, the operator will decompose into its "
                f"constituents and those can have fake impls defined on them."
            )

        # Store the kernel in this holder
        self.kernel = Kernel(func, source)

        # Also register the fake impl to the Meta dispatch key
        if self.lib is None:
            ns = self.qualname.split("::")[0]
            self.lib = torch.library.Library(ns, "FRAGMENT")
        meta_kernel = construct_meta_kernel(self.qualname, self)
        self.lib.impl(self.qualname, meta_kernel, "Meta")

        # The handle undoes both the Meta registration and the stored kernel.
        def deregister_fake_class():
            if self.lib:
                self.lib._destroy()
                self.lib = None
            self.kernel = None

        return RegistrationHandle(deregister_fake_class)
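

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module). It shows how a
# holder registers a fake impl and how the returned RegistrationHandle undoes
# it. The namespace "mylib", the op "my_sin", and the helper name below are
# hypothetical; real callers normally go through torch.library.register_fake
# rather than constructing an AbstractImplHolder directly.
# ---------------------------------------------------------------------------
def _example_register_and_deregister() -> None:
    # Define a toy operator in a throwaway library fragment.
    lib = torch.library.Library("mylib", "FRAGMENT")
    lib.define("my_sin(Tensor x) -> Tensor")

    holder = AbstractImplHolder("mylib::my_sin")

    def fake_sin(x):
        # Fake impls only propagate metadata (shape/dtype/device), never data.
        return torch.empty_like(x)

    handle = holder.register(fake_sin, source="example")
    # ... trace/compile code that uses mylib::my_sin here ...
    handle.destroy()  # removes the Meta kernel and clears the stored kernel

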
def construct_meta_kernel(
    qualname: str, abstract_impl_holder: AbstractImplHolder
) -> Callable:
    assert abstract_impl_holder.kernel is not None

    @functools.wraps(abstract_impl_holder.kernel.func)
    def meta_kernel(*args, **kwargs):
        assert abstract_impl_holder.kernel is not None
        source = abstract_impl_holder.kernel.source

        def error_on_ctx():
            raise RuntimeError(
                f"Attempted to call get_ctx() for the meta implementation "
                f"for {qualname} (implemented at {source}). "
                f"You have presumably called get_ctx() because the operator "
                f"has a data-dependent output shape; if so, there is no "
                f"such meta implementation and this error is the correct "
                f"behavior."
            )

        with set_ctx_getter(error_on_ctx):
            return abstract_impl_holder.kernel(*args, **kwargs)

    return meta_kernel


def get_none():
    return None


global_ctx_getter: Callable = get_none


@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
    global global_ctx_getter
    prev = global_ctx_getter
    try:
        global_ctx_getter = ctx_getter
        yield
    finally:
        global_ctx_getter = prev
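

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) of the ctx-getter swap
# above. torch.library.get_ctx() is expected to return the result of
# global_ctx_getter(), so installing a getter for the duration of a kernel
# call controls what get_ctx() hands back. The helper name below is
# hypothetical.
# ---------------------------------------------------------------------------
def _example_ctx_getter_swap() -> None:
    sentinel = object()

    assert global_ctx_getter() is None  # the default getter, get_none
    with set_ctx_getter(lambda: sentinel):
        # While the context is active, lookups see the installed getter.
        assert global_ctx_getter() is sentinel
    # On exit, the previous getter is restored.
    assert global_ctx_getter() is None

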
class AbstractImplCtx:
    """
    Context object for writing fake implementations for custom operators.
    """

    def __init__(self, _fake_mode, _op):
        self._fake_mode = _fake_mode
        self._shape_env = _fake_mode.shape_env
        self._op = _op

    @deprecated(
        "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
        category=FutureWarning,
    )
    def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
        return self.new_dynamic_size(min=min, max=max)

    def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
        """Constructs a new symint (symbolic int) representing a data-dependent value.

        This is useful for writing the fake implementation (which is necessary
        for torch.compile) for a CustomOp where an output Tensor has a size
        that depends on the data of the input Tensors.

        Args:
            min (int): A statically known inclusive lower bound for this symint. Default: 0
            max (Optional[int]): A statically known inclusive upper bound for this
                symint. Default: None

        .. warning::

            It is important that the ``min`` and ``max`` (if not None) values are
            set correctly; otherwise, there will be undefined behavior under
            torch.compile. In particular, torch.compile specializes on sizes 0
            and 1, which is why the deprecated ``create_unbacked_symint``
            defaults ``min`` to 2.

            You must also verify that your implementation on concrete Tensors
            (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
            to the symint respects these constraints. The easiest way to do this
            is to add an assertion in the CPU/CUDA/etc implementation that the
            size follows these bounds.

        Example::

            >>> import numpy as np
            >>>
            >>> # An operator with data-dependent output shape
            >>> lib = torch.library.Library("mymodule", "FRAGMENT")
            >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
            >>>
            >>> @torch.library.register_fake("mymodule::custom_nonzero")
            >>> def _(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in a fake impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
            >>> def _(x):
            >>>     x_np = x.numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)

        """
        if (
            self._shape_env is None
            or not self._shape_env.allow_dynamic_output_shape_ops
        ):
            raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)

        if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, max={max}): expected "
                f"min and max to be statically known ints but got SymInt. "
                f"This is not supported."
            )

        if min < 0:
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
                f"greater than or equal to 0: this API can only create "
                f"non-negative sizes."
            )

        result = self._shape_env.create_unbacked_symint()
        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
            result, min=min, max=max
        )
        return result
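

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) exercising
# new_dynamic_size() directly. AbstractImplCtx only reads ``.shape_env`` off
# the fake mode, so a stand-in namespace is used; the ShapeEnv keyword below
# is an assumption about its constructor (it mirrors
# torch._dynamo.config.capture_dynamic_output_shape_ops). The helper name is
# hypothetical.
# ---------------------------------------------------------------------------
def _example_new_dynamic_size() -> None:
    import types

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv(allow_dynamic_output_shape_ops=True)  # assumed kwarg
    ctx = AbstractImplCtx(types.SimpleNamespace(shape_env=shape_env), None)

    # A data-dependent size with statically known inclusive bounds.
    nnz = ctx.new_dynamic_size(min=0, max=128)
    assert isinstance(nnz, torch.SymInt)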