# module_tracker.py

# mypy: allow-untyped-defs
import logging
import weakref
from typing import Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


logger = logging.getLogger(__name__)

__all__ = ["ModuleTracker"]
class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    is being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed, identified by their fqn (fully qualified name, also used as
    the key within the state_dict).
    You can access the ``is_bw`` attribute to know whether you are currently running backward
    or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet; please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))
    """

    parents: Set[str]
    """
    A Set containing the fqn of each module currently running its forward
    """
    def __init__(self):
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            # Runs at the end of the current backward pass: reset the tracked
            # hierarchy so that only the "Global" entry remains.
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True
    @property
    def is_bw(self):
        """
        A boolean marking if this is currently running during the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1
    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            # On first visit, record the fqn of every direct child so nested
            # calls resolve to "Parent.child" style names.
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name
    def _get_append_fn(self, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents:
                logger.info(
                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )
            self.parents.add(name)

        return fn
    def _get_pop_fn(self, name, is_bw):
        def fn(*args):
            if name in self.parents:
                self.parents.remove(name)
            else:
                logger.info(
                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
                    name,
                    "backward" if is_bw else "forward",
                )

        return fn
    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # Once gradients for all the inputs have been computed, this
            # module's backward is done: pop its name from ``parents``.
            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # When gradients start flowing into the outputs, this module's
            # backward has started: push its name onto ``parents``.
            register_multi_grad_hook(tensors, self._get_append_fn(name, True))
    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
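

# The class docstring above sketches usage by monkeypatching ``torch.nn.functional.linear``.
# Below is an additional, self-contained usage sketch; it is not part of the original module,
# and the ``Probe`` module, the ``Sequential`` layout, and the tensor shapes are illustrative
# assumptions. It prints ``tracker.parents`` and ``tracker.is_bw`` from inside a forward call
# and again from a tensor gradient hook during backward.
if __name__ == "__main__":
    with ModuleTracker() as tracker:

        class Probe(torch.nn.Module):
            """Tiny module that reports the tracked hierarchy from its forward."""

            def forward(self, x):
                print(f"forward: parents={tracker.parents}, is_bw={tracker.is_bw}")
                return x * 2

        net = torch.nn.Sequential(torch.nn.Linear(4, 4), Probe())
        out = net(torch.rand(3, 4, requires_grad=True))

        # During backward, the multi-grad hooks registered by the tracker update
        # ``parents`` again and ``is_bw`` reports True.
        out.register_hook(
            lambda grad: print(f"backward: parents={tracker.parents}, is_bw={tracker.is_bw}")
        )
        out.sum().backward()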