hooks.py 644 B

1234567891011121314151617181920212223242526272829
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. from typing import Callable, List, TYPE_CHECKING
  4. if TYPE_CHECKING:
  5. import torch
  6. # Executed in the order they're registered
  7. INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
  8. @contextlib.contextmanager
  9. def intermediate_hook(fn):
  10. INTERMEDIATE_HOOKS.append(fn)
  11. try:
  12. yield
  13. finally:
  14. INTERMEDIATE_HOOKS.pop()
  15. def run_intermediate_hooks(name, val):
  16. global INTERMEDIATE_HOOKS
  17. hooks = INTERMEDIATE_HOOKS
  18. INTERMEDIATE_HOOKS = []
  19. try:
  20. for hook in hooks:
  21. hook(name, val)
  22. finally:
  23. INTERMEDIATE_HOOKS = hooks