# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
    CJvpInterpreterPtr,
    CVmapInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
    RandomnessType,
    TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
  19. """
  20. This file contains the functorch integration with PyDispatcher.
  21. PyDispatcher does not understand functorch's DynamicLayerStack dispatching
  22. logic because it is entirely implemented in C++ in the fallbacks for two
  23. dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
  24. to directly reuse C++ boxed fallbacks).
  25. Instead of trying to hammer PyDispatcher into understanding those fallbacks,
  26. we re-implement the logic of peeking the top of the stack for an interpreter,
  27. selecting the interpreter to dispatch on, etc, in Python. This leads to a
  28. simpler design.
  29. The main difference between C++ functorch and PyDispatcher's functorch logic
  30. is that:
  31. - C++ functorch needs to manually tweak dispatch keys to ping-pong between
  32. DynamicLayerFrontMode and DynamicLayerBackMode.
  33. - PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  34. and asks it to execute the rule associated with the Interpreter.
  35. In C++ we do the ping-pong because e.g. vmap rules are associated with the
  36. batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
  37. the user to register a batching rule directly to a transform that an
  38. interpreter then invokes.
  39. """


# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. E.g. for vmap, this is invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()

    def get_state(self):
        raise NotImplementedError

    def check_state(self, state):
        return state == self.get_state()


@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    try:
        saved = pop_dynamic_layer_stack()
        yield
    finally:
        push_dynamic_layer_stack(saved)
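
# A transform rule typically calls lower() to run the underlying operation
# under the next interpreter on the stack. A rough usage sketch (`interpreter`
# and `op` are stand-ins provided by the dispatching machinery in this file):
#
#     def some_rule(interpreter, *args, **kwargs):
#         # ... transform-specific handling of args/kwargs ...
#         with interpreter.lower():
#             return op(*args, **kwargs)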


@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
    stack = []
    try:
        while torch._C._functorch.peek_interpreter_stack() is not None:
            stack.append(pop_dynamic_layer_stack())
        yield list(stack)
    finally:
        while stack:
            push_dynamic_layer_stack(stack.pop())


@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
    pushed = []
    try:
        for s in reversed(stack):
            push_dynamic_layer_stack(s)
            pushed.append(s)
        yield
    finally:
        for s in reversed(pushed):
            # TODO: would be nice to assert that the layers are the same, but
            # Python object identity is not preserved
            pop_dynamic_layer_stack()
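
# Both helpers above undo their changes on exit. A minimal sketch for the
# clearing helper (the restoring helper is its mirror image, given a previously
# captured stack):
#
#     with temporarily_clear_interpreter_stack() as saved:
#         # no functorch interpreters are active in here; `saved` holds the
#         # popped C++ interpreter layers, former top of stack first
#         ...
#     # the original stack has been pushed back at this point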


class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")

    def get_state(self):
        return (self.key().name, self.level(), self.randomness())
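
# A batching rule registered under TransformType.Vmap receives the
# VmapInterpreter as its first argument and can query it for vmap-specific
# information. A hypothetical sketch (`my_op` and its rule are made up, and the
# rule assumes the batch dimension has been moved to the front; in PyTorch
# itself such rules are typically registered via the operator's py_impl
# decorator rather than by writing to functorch_table directly):
#
#     def my_op_vmap_rule(interpreter: VmapInterpreter, x, dim):
#         batch_size = interpreter.batch_size()  # available if the rule needs it
#         # the batched input carries an extra leading dim, so shift `dim` by one
#         # and run the underlying op beneath this vmap layer
#         with interpreter.lower():
#             return my_op(x, dim + 1)
#
#     my_op.functorch_table[TransformType.Vmap] = my_op_vmap_rule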


@contextlib.contextmanager
def nested(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_grad_mode())


class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_fwd_grad_mode())


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()

    def get_state(self):
        return (self.key().name, self.level())


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
    cis = torch._C._functorch.get_interpreter_stack()
    if cis is None:
        return []
    return [coerce_cinterpreter(ci) for ci in cis]


def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
    # There are four possible cases covered here:
    # 1. Current stack empty AND stack when generated not empty -> Invalidate
    # 2. Current stack not empty AND stack when generated empty -> Invalidate
    # 3. Current stack and generated stack empty -> Valid FX graph
    # 4. Current stack and generated stack not empty -> Valid if both states match
    peek = torch._C._functorch.peek_interpreter_stack()
    if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0):
        return False

    cis = retrieve_all_functorch_interpreters()
    return len(cis) == len(states) and all(
        ci.check_state(state) for ci, state in zip(cis, states)
    )
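
# The `states` compared here are the tuples produced by get_state() on each
# interpreter. A sketch of how a caller (e.g. a cached graph) might use this;
# `recompile` is a hypothetical stand-in for whatever invalidation the caller
# performs:
#
#     # at capture time
#     states = [ci.get_state() for ci in retrieve_all_functorch_interpreters()]
#     ...
#     # at reuse time: the capture is only valid if the transform stack looks
#     # the same as when it was generated
#     if not compare_functorch_state(states):
#         recompile()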


def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
    )
    return interpreter.process(op, args, kwargs)
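
# The interpreter stack that this file consults is populated by the functorch
# transforms themselves. An illustrative sketch of inspecting it from inside
# torch.vmap, assuming this module is importable as torch._functorch.pyfunctorch:
#
#     import torch
#     from torch._functorch.pyfunctorch import (
#         retrieve_current_functorch_interpreter,
#     )
#
#     def f(x):
#         interp = retrieve_current_functorch_interpreter()
#         # inside vmap, the top of the DynamicLayerStack is a vmap layer
#         assert interp.key().name == "Vmap"
#         assert interp.batch_size() == 3
#         return x.sin()
#
#     torch.vmap(f)(torch.randn(3))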