# autograd_function.py
# mypy: allow-untyped-defs
from typing import Any, NamedTuple, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _unwrap_for_grad,
    _wrap_for_grad,
    current_level,
    TransformType,
)
from torch._functorch.apis import vmap
from torch._functorch.utils import enable_single_level_autograd_function
from torch._functorch.vmap import (
    _add_batch_dim,
    _broadcast_to_and_flatten,
    restore_vmap,
    unwrap_batched,
    wrap_batched,
)
from torch._ops import HigherOrderOperator
from torch.autograd.forward_ad import _set_fwd_grad_enabled


# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self):
        super().__init__("custom_function_call")

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)
  47. # "custom_function_call"
  48. # This is the mechanism for an autograd.Function that works with functorch transforms.
  49. # It wraps an autograd.Function; interactions with functorch transforms are defined
  50. # via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
  51. # dispatcher.
  52. custom_function_call = CustomFunctionHigherOrderOperator()


# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
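#
# As an illustrative sketch (MySin is a hypothetical user-defined
# autograd.Function, not something defined in this file), grad(MySin.apply)(x)
# roughly proceeds as:
#
#   custom_function_call(MySin, x)
#   -> custom_function_call_grad(grad_interpreter, MySin, x)
#      -> MySinGenerated.apply(x)   # the _SingleLevelFunction built below
#         - unwraps x at the current grad level
#         - redispatches to custom_function_call(MySin, unwrapped_x), which,
#           with no more transforms active, just calls MySin.apply
#         - rewraps the output for the current grad level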
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands
        )
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, wrap_fn
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f"{autograd_function.__name__}Generated"
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
        },
    )
    return Generated


# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is whether or not
# the out_dims are specified.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"
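
# For example, a user's vmap staticmethod may legitimately return
# out_dims=None for an output that has no batch dimension (a hypothetical
# sketch, not defined in this file):
#
#   @staticmethod
#   def vmap(info, in_dims, x):
#       # The output is shared across all batch entries, so out_dims is None.
#       return torch.tensor(0.0), None
#
# That is different from NO_OUT_DIMS, which wrap_outputs_maintaining_identity
# uses to mean "the caller (backward/jvp) did not pass out_dims at all".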


# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
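

# An illustrative sketch of the identity check above (MyInplace is
# hypothetical, not defined in this file):
#
#   class MyInplace(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           x.sin_()
#           return x               # returns an input object unchanged
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.mark_dirty(inputs[0])
#
# Inside a transform, the unwrapped call returns the *unwrapped* input object.
# Rewrapping it would create a new wrapper object and violate mark_dirty's
# "returned input has the same object identity" contract, so the id() lookup
# above maps it back to the original (wrapped) input instead.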


# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# in vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.


class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if len(result) != 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")


@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/main/notes/extending.func.html"
        )

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return (
            output
            if out_dim is None
            else _add_batch_dim(output, out_dim, current_level)
        )

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )
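

# An illustrative sketch of the vmap-staticmethod contract handled above
# (MyPointwiseMul is hypothetical, not defined in this file):
#
#   class MyPointwiseMul(torch.autograd.Function):
#       @staticmethod
#       def forward(x, y):
#           return x * y
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(*inputs)
#
#       @staticmethod
#       def vmap(info, in_dims, x, y):
#           # x and y arrive unwrapped, with their vmapped dims physically
#           # present at the positions recorded in in_dims (or None if an
#           # argument is not being vmapped over). info is the VmapInfo above.
#           x_bdim, y_bdim = in_dims
#           x = x.movedim(x_bdim, 0) if x_bdim is not None else x.unsqueeze(0)
#           y = y.movedim(y_bdim, 0) if y_bdim is not None else y.unsqueeze(0)
#           # Return (output, out_dims): the output's batch dim is dim 0.
#           return x * y, 0
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           x, y = ctx.saved_tensors
#           return grad_output * y, grad_output * x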


def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()
    )

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())
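

# A minimal sketch of the generate_vmap_rule=True path (MySquare is
# hypothetical, not defined in this file): the user opts in and the batching
# rule is derived by re-running forward/setup_context/backward under an inner
# vmap via vmapify_autograd_function below.
#
#   class MySquare(torch.autograd.Function):
#       generate_vmap_rule = True
#
#       @staticmethod
#       def forward(x):
#           return x ** 2
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(inputs[0])
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           (x,) = ctx.saved_tensors
#           return 2 * x * grad_output
#
#   # vmap(MySquare.apply)(torch.randn(B, 3)) routes through the code above.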


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, generate_vmap_rule, *operands
):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")


def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit worried about nested
    #   vmap(vmap(...)) but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batched tensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context,
            (saved_tensors_bdims, tangent_in_dims),
            batch_size,
            randomness,
        )(ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context,
            ((saved_tensors_bdims, out_dims),),
            batch_size,
            randomness,
        )((ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f"Vmapped{autograd_function.__name__}"
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
            "generate_vmap_rule": True,
        },
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims


# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [
        None if tangent is None else in_dim
        for in_dim, tangent in zip(flat_in_dims, flat_tangents)
    ]
    return pytree.tree_unflatten(result, spec)
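

# For example (values are illustrative): with input_dims = (0, 0) and
# tangents = (t, None) -- i.e. no tangent was provided for the second
# input -- this returns (0, None), so restore_vmap does not try to treat
# the missing tangent as batched.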


# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.sum()
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         ctx.x_shape = inputs[0].shape
#     @staticmethod
#     def backward(ctx, gy):
#         return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return vmap(Sum.forward, in_dims)(x)
#
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         Sum.setup_context(ctx, inputs, outputs)
#
#     @staticmethod
#     def backward(ctx, gy):
#         def backward_no_context(gy):
#             return gy.expand(ctx.x_shape)
#
#         dims = (0,)
#         gx = vmap(backward_no_context, dims)(gy)
#         return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have per-example shape [4]). Performing vmap over setup_context means the
# saved shape is [4], which leads to the correct result shape for gx.


# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} field on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)


# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors


class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
        *WrappedCtx._pt_reserved_attrs,
    )

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims
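

# A shape-level sketch of CtxCustomSave (values are illustrative): if a user's
# setup_context calls ctx.save_for_backward(t) with a tensor t that is batched
# at the current level with bdim 0, the wrapper saves the underlying plain
# tensor on the real ctx and records bdims = (0,) in _pt_saved_tensors_bdims.
# backward()/jvp() in vmapify_autograd_function then feed those bdims back
# into restore_vmap so the saved tensors are treated as batched again.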


def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in zip(
            grad_input,
            grad_input_bdim,
            input_bdim,
            target_shape_without_bdim_to_reduce_to,
        )
    )
    return result


def reductify_leaf(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to`
    # argument: if it is not None, we do the reduction manually; otherwise, we do
    # not reduce.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(
            torch.Tensor.sum_to_size,
            in_dims=(grad_input_bdim, None),
            out_dims=input_bdim,
        )(grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
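

# A worked shape example for reductify_leaf (values are illustrative,
# following [example 1] above): suppose input_bdim=0, batch_size=B,
# target_shape_without_bdim_to_reduce_to=(4,), and the user's backward
# returned grad_input: Tensor[3, 4] with grad_input_bdim=None (an unbatched,
# broadcasted gradient). reductify_leaf then:
# - unsqueezes and expands grad_input to Tensor[B, 3, 4] with grad_input_bdim=0
# - vmaps Tensor.sum_to_size over dim 0, reducing each [3, 4] slice to [4]
# - returns a Tensor[B, 4] whose batch dim matches input_bdim=0.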


def autograd_function_forward_rewritten(original_forward, original_setup_context):
    def new_forward(ctx, *args, **kwargs):
        output = original_forward(*args, **kwargs)
        original_setup_context(ctx, args, output)
        return output

    return new_forward
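

# A minimal sketch of what autograd_function_forward_rewritten produces
# (MySin is hypothetical, not defined in this file): given the new-style pair
#
#   @staticmethod
#   def forward(x):
#       return x.sin()
#
#   @staticmethod
#   def setup_context(ctx, inputs, output):
#       ctx.save_for_backward(inputs[0])
#
# it returns the equivalent old-style forward:
#
#   def new_forward(ctx, x):
#       output = MySin.forward(x)
#       MySin.setup_context(ctx, (x,), output)
#       return output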


class AutogradFunctionApply(HigherOrderOperator):
    def __init__(self):
        super().__init__("autograd_function_apply")

    def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
        saved_values = None
        args_tensor_mask = fwd_kwargs["args_tensor_mask"]
        length_of_tensor_args = sum(args_tensor_mask)
        # Filter out the original tensor args from fwd_args;
        # lifted freevars should not be args of ApplyTemplate.apply
        # since we don't need to calculate gradients for them.
        new_fwd_args = fwd_args[:length_of_tensor_args]

        class ApplyTemplate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_values
                output, saved_values = fwd(None, *fwd_args)
                return output

            @staticmethod
            def backward(ctx, *grad):
                return bwd(None, *grad, *saved_values)

        return ApplyTemplate.apply(*new_fwd_args)


autograd_function_apply = AutogradFunctionApply()
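

# An illustrative sketch of the calling convention ApplyTemplate expects
# (the fwd/bwd callables are normally produced by a tracer; the bodies below
# are made up for illustration):
#
#   def fwd(ctx, x, lifted_freevar):
#       out = x * lifted_freevar
#       saved_values = (x, lifted_freevar)   # returned, not stored on ctx
#       return out, saved_values
#
#   def bwd(ctx, grad_out, x, lifted_freevar):
#       # One gradient per tensor arg of ApplyTemplate.apply (here: just x).
#       return grad_out * lifted_freevar
#
#   out = autograd_function_apply(fwd, bwd, x, lifted_freevar,
#                                 args_tensor_mask=[True, False])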