mutation_guard.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="method-assign"
  3. import functools
  4. import weakref
  5. import torch.nn
  6. from torch.nn import Module
  7. from . import config
  8. from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks
  9. class MutationTracker:
  10. db = ExactWeakKeyDictionary()
  11. def __init__(self):
  12. self.mutation_count = 0
  13. self.watchers = []
  14. def on_mutation(self, name):
  15. self.mutation_count += 1
  16. tmp = self.watchers
  17. self.watchers = []
  18. for ref in tmp:
  19. guarded = ref()
  20. if guarded is not None:
  21. guarded.invalidate(ref)
  22. def track(self, guarded_code):
  23. self.watchers.append(weakref.ref(guarded_code))
  24. def watch(obj, guarded_code):
  25. """invalidate guarded_code when obj is mutated"""
  26. ensure_patched(type(obj))
  27. if obj not in MutationTracker.db:
  28. MutationTracker.db[obj] = MutationTracker()
  29. tracker = MutationTracker.db[obj]
  30. tracker.track(guarded_code)
  31. def ensure_patched(cls):
  32. if getattr(cls, "___needs_mutation_patch", True):
  33. cls.___needs_mutation_patch = False
  34. original_setattr = cls.__setattr__
  35. @functools.wraps(original_setattr)
  36. def custom_setattr(self, key, value):
  37. try:
  38. MutationTracker.db[self].on_mutation(key)
  39. except KeyError:
  40. pass
  41. return original_setattr(self, key, value)
  42. cls.__setattr__ = custom_setattr
  43. class GenerationTracker:
  44. generation = 0
  45. dynamic_classes = ExactWeakKeyDictionary()
  46. generation_values = ExactWeakKeyDictionary()
  47. @classmethod
  48. def tag(cls, obj):
  49. cls.generation_values[obj] = cls.generation
  50. @staticmethod
  51. def mark_class_dynamic(cls):
  52. assert issubclass(cls, torch.nn.Module)
  53. GenerationTracker.dynamic_classes[cls] = True
  54. @classmethod
  55. def get_generation_value(cls, obj):
  56. if obj not in cls.generation_values:
  57. return -1
  58. return cls.generation_values[obj]
  59. @classmethod
  60. def check(cls, obj):
  61. return (
  62. obj in cls.generation_values
  63. and cls.generation_values[obj] == cls.generation
  64. )
  65. @classmethod
  66. def clear(cls):
  67. cls.generation = 0
  68. cls.dynamic_classes = ExactWeakKeyDictionary()
  69. cls.generation_values = ExactWeakKeyDictionary()
  70. def is_dynamic_nn_module(obj, is_export):
  71. """Check for nn.Modules() created dynamically or mutated"""
  72. if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__:
  73. # A monkey patched `.forward` indicates something wacky is going on
  74. return True
  75. if hasattr(obj, "torchdynamo_force_dynamic"):
  76. return obj.torchdynamo_force_dynamic
  77. if is_lazy_module(obj):
  78. return False
  79. # For export, we will have to fix
  80. # 1) Input signature problem because params are lifted as inputs
  81. # 2) nn module stack info changes
  82. # 3) adjust failing tests
  83. if (
  84. isinstance(obj, torch.nn.Module)
  85. and config.inline_inbuilt_nn_modules
  86. and not is_export
  87. ):
  88. return True
  89. if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks():
  90. return True
  91. dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
  92. obj
  93. )
  94. return dyn
  95. def install_generation_tagging_init():
  96. """
  97. Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
  98. so we can detect nn.Module instances created dynamically inside forward methods.
  99. """
  100. if getattr(Module, "___needs_generation_tag_patch", True):
  101. init = Module.__init__
  102. def patched_init(self, *args, **kwargs):
  103. init(self, *args, **kwargs)
  104. GenerationTracker.tag(self)
  105. Module.__init__ = patched_init
  106. setstate = Module.__setstate__
  107. def patched_setstate(self, state):
  108. setstate(self, state)
  109. GenerationTracker.tag(self)
  110. Module.__setstate__ = patched_setstate
  111. Module.___needs_generation_tag_patch = False # type: ignore[attr-defined]
  112. GenerationTracker.generation += 1