hooks.py

# mypy: allow-untyped-defs
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]

class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0  # class-level counter used to assign a unique id to each handle

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        # Hold only a weak reference so the handle does not keep the hooks
        # dictionary (and hence the module it belongs to) alive.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        if self.extra_dict_ref is None:
            return (self.hooks_dict_ref(), self.id)
        else:
            return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()
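
# The example below is a hedged usage sketch, not part of the original module.
# It only assumes the public ``nn.Module`` hook API, whose ``register_*_hook``
# methods return a ``RemovableHandle``; the helper is defined but never called
# at import time.
def _example_removable_handle():
    m = torch.nn.Linear(2, 2)

    # register_forward_hook returns a RemovableHandle.
    handle = m.register_forward_hook(lambda mod, inp, out: None)
    m(torch.randn(1, 2))
    handle.remove()  # detaches the hook; calling remove() again is a no-op

    # Handles are also context managers: __exit__ calls remove() automatically.
    with m.register_forward_hook(lambda mod, inp, out: None):
        m(torch.randn(1, 2))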

def unserializable_hook(f):
    """
    Mark a function as an unserializable hook with this decorator.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f
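
# Hedged sketch (not part of the original file): how the decorator is meant to
# be used. An undecorated hook registered on a tensor would normally trigger
# the warning emitted by ``warn_if_has_hooks`` below when the tensor is
# serialized; decorating the hook opts out of that warning.
@unserializable_hook
def _example_grad_logger(grad):
    # A tensor hook receives the gradient and may return a replacement.
    print("grad norm:", grad.norm().item())

# Typical use:
#     t = torch.randn(3, requires_grad=True)
#     t.register_hook(_example_grad_logger)
#     torch.save(t, "t.pt")  # no "will not be serialized" warning for this hook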

def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning")

class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
    - Ignoring non-Tensor inputs and replacing them with None before calling the user hooks
    - Generating the proper Node to capture a set of Tensors' gradients
    - Linking the gradients captured for the outputs with the gradients captured for the inputs
    - Calling the user hooks once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        self.grad_outputs = None

        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        # Scatter `values` back into a tuple of length `size`, placing None at
        # every position that does not correspond to a Tensor.
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val
        return tuple(res)

    def _unpack_none(self, indices, values):
        # Inverse of _pack_with_none: keep only the entries at `indices`.
        res = []
        for idx in indices:
            res.append(values[idx])
        return tuple(res)
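
    # Illustrative note (added comment, not in the original source): with
    # indices=[0, 2], values=(gA, gB) and size=3, _pack_with_none returns
    # (gA, None, gB); _unpack_none([0, 2], (gA, None, gB)) gives back (gA, gB).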

    def _set_user_hook(self, grad_fn):
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for user_hook in self.user_hooks:
                out = user_hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError(
                        "Backward hook returned an invalid number of grad_input, "
                        f"got {len(out)}, but expected {len(res)}"
                    )

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)
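
    # Added explanatory note (not in the original source): the hook registered
    # above fires on the grad_fn that guards the module *inputs*. By the time it
    # runs, the output-side hook installed in setup_output_hook has already
    # stored the output gradients in self.grad_outputs, so the user hooks can be
    # called with both grad_input and grad_output.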

    def _apply_on_tensors(self, fn, args):
        # Can be used to apply the given function to the tensors contained in the
        # args. Will return updated args and the tensor indices.
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors
                    if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        if type(args) is tuple:
            out = tuple(arg_list)
        else:
            out = type(args)(*arg_list)
        return out, tensors_idx
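
    # Added explanatory note (not in the original source): on the fast path
    # (no Tensor argument requires grad, or grad mode is off) the args are
    # returned untouched together with ``None``. Otherwise each Tensor is routed
    # through BackwardHookFunction so that a "BackwardHookFunctionBackward" node
    # exists in the autograd graph, ``fn`` is given that node to attach hooks
    # to, and the rewrapped args plus the Tensor positions are returned.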

    def setup_input_hook(self, args):
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(
                    self.output_tensors_index, grad_output, self.n_outputs
                )

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError(
                                "Backward pre hook returned an invalid number of grad_output, "
                                f"got {actual_len}, but expected {expected_len}"
                            )
                        self.grad_outputs = hook_grad_outputs

                # We need to be able to clear self.grad_outputs but also return it
                local_grad_outputs = self.grad_outputs

                # Special case: if no input required gradients, this hook should
                # call the user hooks directly.
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

                if local_grad_outputs is not None:
                    assert self.output_tensors_index is not None  # mypy
                    return tuple(local_grad_outputs[i] for i in self.output_tensors_index)

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
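
# Hedged end-to-end sketch (not part of the original file): ``BackwardHook`` is
# the internal machinery behind the public ``nn.Module.register_full_backward_hook``
# and ``register_full_backward_pre_hook`` APIs, so this example goes through
# those entry points rather than constructing a BackwardHook directly. The
# helper is defined but never called at import time.
def _example_full_backward_hooks():
    m = torch.nn.Linear(3, 2)

    def pre_hook(module, grad_output):
        # Runs before the module's grad_input is computed; may return
        # replacement grad_output values (same length) or None.
        print("grad_output shapes:", [g.shape for g in grad_output if g is not None])

    def hook(module, grad_input, grad_output):
        # Runs once both grad_input and grad_output are available; non-Tensor
        # positions show up as None, matching _pack_with_none above.
        print("grad_input shapes:", [None if g is None else g.shape for g in grad_input])

    with m.register_full_backward_pre_hook(pre_hook), m.register_full_backward_hook(hook):
        out = m(torch.randn(4, 3, requires_grad=True))
        out.sum().backward()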