# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
from torch._dynamo.external_utils import call_backward, call_hook
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
    decompose,
    disable_autocast_cache,
    disable_proxy_modes_tracing,
    fetch_object_proxy,
    ProxyTorchDispatchMode,
    PythonKeyTracer,
    track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.traceback import preserve_node_meta, set_stack_trace
from torch.utils._traceback import CapturedTraceback


if TYPE_CHECKING:
    from torch.fx.proxy import Proxy


compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")


def snapshot_verbose_logging_enabled():
    return torch._logging._internal.log_state.is_artifact_enabled(
        "compiled_autograd_verbose"
    )


def cpp_verbose_log_fn(msg: str) -> None:
    verbose_log.debug(msg)


def snapshot_cudagraph_enabled():
    return torch._inductor.config.triton.cudagraphs


def maybe_clone(x):
    if x is not None:
        return clone_preserve_strides(x)
    return x


class AutogradCompilerInstance:
    def __init__(self, compiler_fn) -> None:
        self.compiler_fn = compiler_fn
        self.stack = contextlib.ExitStack()
        self.close = self.stack.close
        self.shape_env = ShapeEnv()
        self.fake_tensor_mode = FakeTensorMode(
            allow_fallback_kernels=True,
            allow_non_fake_inputs=True,
            shape_env=self.shape_env,
        )
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None

    def wrap_fake(self, x, source):
        assert isinstance(x, torch.Tensor)
        return self.fake_tensor_mode.from_tensor(x, source=source)

    @staticmethod
    def source(name, idx) -> GetItemSource:
        return GetItemSource(LocalSource(name), idx)

    def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
        counters["compiled_autograd"]["captures"] += 1
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {})
        sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {})
        self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {})

        # tensor inputs to fake tensors
        inputs = [
            self.wrap_fake(x, self.source("inputs", idx))
            for idx, x in enumerate(inputs)
        ]
        proxies = [args_proxy[i] for i in range(len(inputs))]
        self.bind_tensors_to_proxies(inputs, proxies)

        # size inputs to symints
        sizes = [
            self.shape_env.create_unspecified_symint_and_symbol(
                val,
                self.source("sizes", idx),
                DimDynamic.DYNAMIC,
            )
            for idx, val in enumerate(sizes)
        ]
        self.bind_tensors_to_proxies(sizes, sizes_proxy)

        # TODO(jansel): are all these modes needed?
        self.stack.enter_context(decompose({}))
        self.stack.enter_context(self.fake_tensor_mode)
        self.stack.enter_context(self.proxy_mode.sym_mode)
        self.stack.enter_context(self.proxy_mode)
        self.stack.enter_context(disable_autocast_cache())
        self.stack.enter_context(preserve_node_meta())
        return inputs, sizes

    def proxy_call_backward(
        self,
        inputs,
        output_metadatas,
        saved_tensors,
        backward_idx: int,
    ):
        assert self.hooks_proxy is not None
        backward_c_function = self.hooks_proxy[backward_idx]  # type: ignore[index]
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_backward,
            args=(
                backward_c_function,
                self.to_proxy(saved_tensors),
                *self.to_proxy(inputs),
            ),
            kwargs={},
        )

        with disable_proxy_modes_tracing():
            # create fake Tensors
            grad_ins: List[Optional[torch.Tensor]] = []
            for output_metadata in output_metadatas:
                if output_metadata is None:
                    grad_ins.append(None)
                    continue

                layout, device, dtype, size = output_metadata
                grad_ins.append(
                    torch.empty(size=size, dtype=dtype, layout=layout, device=device)
                )
            self.bind_tensors_to_proxies(grad_ins, proxies)

        return tuple(grad_ins)

    def proxy_call_hook(self, hook, *args):
        return self.fx_tracer.create_proxy(
            "call_function",
            call_hook,
            (
                hook,
                *[self.to_proxy(x) for x in args],
            ),
            {},
        )

    def tensor_pre_hook(self, inputs, hook_id, i: int):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxy = self.proxy_call_hook(
            hook,
            inputs[i],
        )
        with disable_proxy_modes_tracing():
            inputs[i] = maybe_clone(inputs[i])
            self.bind_tensors_to_proxies([inputs[i]], [proxy])
        return inputs

    def pre_hook(self, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            inputs,
        )
        with disable_proxy_modes_tracing():
            inputs = [maybe_clone(x) for x in inputs]
            self.bind_tensors_to_proxies(inputs, proxies)
        return inputs

    def post_hook(self, outputs, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            outputs,
            inputs,
        )
        with disable_proxy_modes_tracing():
            outputs = [maybe_clone(x) for x in outputs]
            self.bind_tensors_to_proxies(outputs, proxies)
        return outputs

    def post_acc_grad_hook(self, input, hook_id):
        assert isinstance(input, torch.Tensor)
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            input,
        )
        with disable_proxy_modes_tracing():
            input = [maybe_clone(input)]
            self.bind_tensors_to_proxies(input, proxies)
        return input

    # Note: [Compiled autograd and cudagraphs]
    # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
    # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph
    # with some cpu 0-dim tensor inputs. To prevent cudagraphs from skipping the entire graph, we
    # move the scalar tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim
    # tensors too.
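    #
    # Illustrative sketch (not executed here): ATen/prims ops accept a cpu 0-dim
    # "wrapped scalar" tensor mixed with cuda tensors, and moving that scalar to
    # cuda preserves the result, e.g.:
    #
    #     x = torch.randn(4, device="cuda")
    #     s = torch.tensor(2.0)  # cpu 0-dim scalar, like the one DivBackward0 lifts
    #     assert torch.equal(torch.div(x, s), torch.div(x, s.cuda()))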
    def move_graph_nodes_to_cuda(self, graph) -> List[int]:
        to_move: Dict[int, torch.fx.Node] = {}
        has_cuda_inputs = False
        nodes = list(graph.nodes)
        assert nodes[0].target == "inputs"
        inputs = nodes[0]
        inputs_users = list(inputs.users.keys())
        # the ordering of the nodes should always be [inputs, sizes, hooks, getitem, getitem_1, ...]
        # where getitem_i accesses inputs[i]
        first_getitem_idx = 3
        assert nodes[first_getitem_idx] == inputs_users[0]
        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
        assert nodes[last_getitem_idx] == inputs_users[-1]
        for i, node in enumerate(inputs_users):
            if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
                has_cuda_inputs = True
                continue

            is_cpu = node.meta["val"].device.type == "cpu"
            is_scalar = len(node.meta["val"].size()) == 0
            if is_cpu and is_scalar:
                node_users = list(node.users.keys())
                if all(
                    isinstance(user.target, torch._ops.OpOverload)
                    and user.target.namespace in ("prims", "aten")
                    for user in node_users
                ):
                    # all users are prims/aten, can move safely
                    to_move[i] = node

        # only move cpu scalars to cuda if there were cuda activations in this graph;
        # this handles the case where cudagraphs is enabled on a cpu-only graph
        if has_cuda_inputs:
            for node in to_move.values():
                node.meta["val"] = node.meta["val"].cuda()

            # return runtime indices we need to move to cuda
            return list(to_move.keys())

        return []

    def end_capture(self, outputs):
        self.stack.close()
        self.fx_tracer.create_node(
            "output",
            "output",
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        self.reorder_accumulate_grad_nodes()
        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
        )
        set_locals_to_steal(graph, ["inputs"])
        compiled_autograd_log.info(
            "%s", lazy_format_graph_code("Compiled autograd graph", graph)
        )
        verbose_log.debug(
            "%s",
            lazy_format_graph_code(
                "Compiled autograd graph", graph, include_device=True
            ),
        )
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )

        def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
            for i in runtime_inputs_to_move:
                inputs[i] = inputs[i].cuda()

            return compiled_fn(inputs, sizes, hooks)

        return runtime_wrapper, self.compiler_fn(graph)

    def reorder_accumulate_grad_nodes(self):
        """
        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
        the graph. This differs from eager mode, which schedules them as soon as possible. This
        pass attempts to reorder the graph to mimic eager behavior.
        """
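        # fx.Node.append(x) inserts x immediately after the node it is called on,
        # so appending each accumulate_grad_ node after its last (graph-order max)
        # argument hoists it to the earliest point where its input is available.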
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=torch.ops.inductor.accumulate_grad_.default
        ):
            arg = max(node.args)  # last arg
            if arg is not node.prev and arg.op != "placeholder":
                arg.append(node)

    def to_proxy(self, t):
        if t is None:
            return None
        if isinstance(t, list):
            return [self.to_proxy(x) for x in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(x) for x in t)
        assert isinstance(t, (torch.Tensor, torch.SymInt))
        return fetch_object_proxy(self.fx_tracer)(t).proxy

    def bind_tensors_to_proxies(self, tensors, proxies):
        if isinstance(proxies, torch.fx.Proxy):
            proxies = [proxies[i] for i in range(len(tensors))]
        assert len(tensors) == len(proxies)
        track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)

    def bind_backward_state(self, index: int):
        assert self.hooks_proxy is not None
        proxy = self.hooks_proxy[index]  # type: ignore[index]
        bw_state = BackwardState()
        track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
        return bw_state
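
    # The last frame of the captured traceback is the "raw_stack_trace = ..."
    # line inside set_node_origin itself; replacing that literal source text with
    # "<node_name> (NodeCall <node_index>)" makes fx nodes traced under this
    # stack trace report the originating autograd node instead of this helper.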
    def set_node_origin(self, node_name, node_index):
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
        new_code = f"{node_name} (NodeCall {node_index})"
        new_stack_trace = raw_stack_trace.replace(
            "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
        )
        set_stack_trace(new_stack_trace)


compiled_autograd_enabled = False

# We may have code like:
# with enable(compiler_fn):
#     ...
#     with disable():
#         ...
#     ...
# The disable() call just wants to disable compiled autograd temporarily.
# But overall the feature is enabled.
#
# The code covered by the disable context manager has no way to know if
# compiled autograd is overall enabled. Use another variable
# compiled_autograd_enabled_count to indicate how many times compiled
# autograd has been enabled in the call stack for this purpose.
compiled_autograd_enabled_count = 0
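
# Illustrative usage sketch (compiler_fn here stands for any callable that takes
# the captured GraphModule and returns a compiled callable, e.g. torch.compile):
#
#     with enable(compiler_fn):      # compiled_autograd_enabled_count == 1
#         loss_a.backward()          # backward is captured and compiled
#         with disable():
#             loss_b.backward()      # temporarily falls back to eager autograd
#         loss_c.backward()          # compiled again after disable() exits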


@contextlib.contextmanager
def enable(compiler_fn):
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
        functools.partial(AutogradCompilerInstance, compiler_fn)
    )
    if snapshot_verbose_logging_enabled():
        torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)
    global compiled_autograd_enabled, compiled_autograd_enabled_count
    compiled_autograd_enabled = True
    compiled_autograd_enabled_count += 1
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        compiled_autograd_enabled_count -= 1
        if not prior:
            compiled_autograd_enabled = False
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


@contextlib.contextmanager
def disable():
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            compiled_autograd_enabled = True
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


# return to starting state of a new process
def reset() -> None:
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    assert compiled_autograd_enabled_count == 0
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)