functional.py

# mypy: allow-untyped-defs
import weakref

import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._ops import OpOverload
from torch.library import Library
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    OperatorName,
    OptionalType,
    SchemaKind,
)

from .autograd import autograd_not_implemented

def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Given a mutable operator, registers the functional variant.

    This API also correctly links the functional variant with the mutable
    operator for the purposes of functionalization.

    All of the new registrations are performed on the ``lib`` passed in.

    Arguments:
        lib (Library): Should be a torch.library.Library object that has
            the same namespace as ``mutable_op``'s namespace.
            lib will be used to register the new functional op as well
            as a functionalization kernel for the ``mutable_op``.
            If you don't have a library handy, use
            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
        new_op_name (str): The name of the functional operator (without the
            namespace). If ``new_op_name`` does not include a namespace, the
            new functional variant will be accessible under
            ``torch.ops.{lib.ns}.{new_op_name}``.
        mutable_op (OpOverload): The mutable custom operator. Note
            that you may need to add a `.default` to it, like
            `torch.ops.aten.abs_.default`.
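
    Example (an illustrative sketch; the ``mylib`` namespace and the
    ``my_abs_into`` operator below are made up for this example and are not
    provided by this module)::

        lib = torch.library.Library("mylib", "FRAGMENT")
        # A "mutable" op: it writes into ``result`` but does not return it.
        lib.define("my_abs_into(Tensor x, Tensor(a!) result) -> ()")

        def my_abs_into_impl(x, result):
            result.copy_(x.abs())

        lib.impl("my_abs_into", my_abs_into_impl, "CompositeExplicitAutograd")

        register_functional_op(
            lib, "my_abs_into_functional", torch.ops.mylib.my_abs_into.default)

        x = torch.randn(3)
        result = torch.empty(3)
        # The functional variant clones ``result``, runs the mutable op on the
        # clone, and returns the clone; ``result`` itself is left untouched.
        new_result = torch.ops.mylib.my_abs_into_functional(x, result)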
  39. """
  40. validate(mutable_op)
  41. schema = functional_schema(new_op_name, mutable_op)
  42. lib.define(schema)
  43. functional_impl = construct_functional_impl(mutable_op)
  44. lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')
  45. functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
  46. # There's no easy way for us to generate the autograd kernel, so we
  47. # use autograd_not_implemented. Also, this makes it so that the user
  48. # is unable to register an autograd formula themselves. This shouldn't
  49. # be a problem if the user doesn't use the functional op direclty
  50. # in their program, but we may need to revist this in the future.
  51. lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')
  52. f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)
  53. lib.impl(mutable_op, f_kernel, 'Functionalize')


def construct_functional_impl(mutable_op):
    def functional_impl(*args):
        # Strategy:
        # - clone args that would have been mutated
        # - run mutable_op
        # - return the cloned args as additional outputs
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone() if arg is not None else None
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)
        result = mutable_op(*new_args)
        if result is None:
            return tuple(extra_rets)
        if isinstance(result, tuple):
            return (*result, *extra_rets)
        return (result, *extra_rets)
    return functional_impl
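
# For intuition, given a hypothetical mutable op
#   foo(Tensor x, Tensor(a!) out) -> ()
# the impl constructed above behaves roughly like:
#
#   def foo_functional(x, out):
#       out_clone = out.clone()
#       foo(x, out_clone)      # mutates the clone, never ``out`` itself
#       return out_clone       # the clone becomes an extra output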


def construct_functionalization_kernel(mutable_op, functional_op):
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            raise RuntimeError(
                f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # Optional Tensor args that were not passed (both the input and the
            # corresponding extra output are None) have nothing to write back.
            if arg is None and new_value is None:
                continue
            assert arg is not None and new_value is not None
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output
    return kernel
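
# Conceptually, under functionalization a call to the mutable op on
# FunctionalTensorWrapper inputs, e.g. (hypothetical op)
#
#   foo(x, out)                        # mutates ``out``
#
# gets re-expressed by the kernel above as
#
#   new_out = foo_functional(x, out)
#   out.copy_(new_out)                 # write-back into the wrapper
#
# except that the write-back goes through the FunctionalTensorWrapper
# machinery (_replace_, _commit_update, _sync) rather than ``copy_``.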


def validate(mutable_op: OpOverload):
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them has its own conventions:
    # - inplace (first input modified in-place and returned as the only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity.
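    # For reference, schemas for each kind look roughly like this (the first
    # two are real aten ops, the third is a hypothetical custom op):
    #   inplace:  abs_(Tensor(a!) self) -> Tensor(a!)
    #   out=:     abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
    #   mutable:  foo(Tensor x, Tensor(a!) out) -> ()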
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and (
            arg.type != BaseType(BaseTy.Tensor)
            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
        ):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input. "
                "Please file an issue.")


def functional_schema(new_op_name, op: OpOverload):
    schema = FunctionSchema.parse(str(op._schema))
    schema = schema.signature().with_name(OperatorName.parse(new_op_name))
    return str(schema)
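
# For example, given a hypothetical mutable schema
#   foo(Tensor x, Tensor(a!) out) -> ()
# and new_op_name "foo_functional", this produces (roughly)
#   foo_functional(Tensor x, Tensor out) -> Tensor
# i.e. alias annotations are dropped and each mutated input becomes an
# extra return of the functional variant.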


def mutable_args(op: OpOverload):
    return tuple(False if arg.alias_info is None else arg.alias_info.is_write
                 for arg in op._schema.arguments)
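
# For a hypothetical schema like ``foo(Tensor x, Tensor(a!) out) -> ()``,
# mutable_args(op) would return (False, True): only ``out`` is written to.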