# mypy: allow-untyped-defs
import logging
import weakref
from typing import Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten

logger = logging.getLogger(__name__)

__all__ = ["ModuleTracker"]
class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    is currently being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed via their fqn (fully qualified name, also used as the key
    within the state_dict).
    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet, please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))
    """

    parents: Set[str]
    """
    A Set containing the fqn for each module currently running their forward
    """
    def __init__(self):
        self.parents = {"Global"}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            # Reset the state once the current backward pass has fully finished
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True
    @property
    def is_bw(self):
        """
        A boolean marking if this is currently running during the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1
    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            # The first time a module is seen, register its children (recursively)
            # under "<parent fqn>.<attribute name>" so nested modules get fqns.
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name
- def _get_append_fn(self, name, is_bw):
- def fn(*args):
- if is_bw:
- self._maybe_set_engine_callback()
- if name in self.parents:
- logger.info(
- "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
- name,
- "backward" if is_bw else "forward",
- )
- self.parents.add(name)
- return fn
- def _get_pop_fn(self, name, is_bw):
- def fn(*args):
- if name in self.parents:
- self.parents.remove(name)
- else:
- logger.info(
- "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
- name,
- "backward" if is_bw else "forward",
- )
- return fn
    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        # Pop the name during backward once the gradients of all inputs that
        # require grad have been computed, i.e. when backward leaves this module.
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        # Push the name back during backward once the gradients of all outputs that
        # require grad have been computed, i.e. when backward enters this module.
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            register_multi_grad_hook(tensors, self._get_append_fn(name, True))
    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
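

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module above: the ``_Demo`` class
    # and the prints below are illustrative assumptions, showing how ``parents``
    # reflects the currently running module hierarchy during forward and how the
    # tracker state looks after backward has finished.
    import torch.nn as nn

    class _Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = nn.Linear(4, 1)

        def forward(self, x):
            # At this point ``parents`` contains "Global" and "_Demo";
            # "_Demo.inner" is added while ``self.inner`` is running.
            print(f"in forward: parents={tracker.parents}, is_bw={tracker.is_bw}")
            return self.inner(x)

    with ModuleTracker() as tracker:
        out = _Demo()(torch.randn(2, 4, requires_grad=True))
        out.sum().backward()
        print(f"after backward: parents={tracker.parents}, is_bw={tracker.is_bw}")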