# mypy: allow-untyped-defs
from contextlib import contextmanager

from torch.fx import GraphModule
from torch.fx.graph_module import (
    _format_import_block,
    reduce_graph_module,
    reduce_package_graph_module,
)
from torch.package import PackageExporter, sys_importer

from ._compatibility import compatibility

_use_lazy_graph_module_flag = False
_force_skip_lazy_graph_module_flag = False


@compatibility(is_backward_compatible=False)
@contextmanager
def _force_skip_lazy_graph_module():
    """
    Skip using the lazy graph module, regardless of the setting of _use_lazy_graph_module.
    Used to skip _LazyGraphModule when testing the inductor torchscript-related backends:
    calling torch.jit.script on a _LazyGraphModule results in the following error:
        https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
    """
    try:
        global _force_skip_lazy_graph_module_flag
        prior = _force_skip_lazy_graph_module_flag
        _force_skip_lazy_graph_module_flag = True
        yield
    finally:
        _force_skip_lazy_graph_module_flag = prior
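

# Illustrative sketch (not part of this module's runtime code; it only shows how
# the two context managers interact):
#
#     with _force_skip_lazy_graph_module():
#         with _use_lazy_graph_module(True):
#             # The skip flag wins: _use_lazy_graph_module_flag stays False, so
#             # _get_graph_module_cls() (defined below) still returns GraphModule.
#             assert _get_graph_module_cls() is GraphModule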


@compatibility(is_backward_compatible=False)
@contextmanager
def _use_lazy_graph_module(should_use: bool):
    try:
        global _use_lazy_graph_module_flag
        prior = _use_lazy_graph_module_flag
        _use_lazy_graph_module_flag = (
            should_use and not _force_skip_lazy_graph_module_flag
        )
        yield
    finally:
        _use_lazy_graph_module_flag = prior


@compatibility(is_backward_compatible=False)
def _get_graph_module_cls():
    return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule


def _make_graph_module(*args, graph_module_cls=None, **kwargs):
    if graph_module_cls is None:
        graph_module_cls = _get_graph_module_cls()

    return graph_module_cls(*args, **kwargs)
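

# Illustrative sketch (not part of this module; the function `f` is an assumed
# example). It shows how the flag selects the class that _make_graph_module uses:
#
#     from torch.fx import symbolic_trace
#
#     def f(x):
#         return x.sin() + 1
#
#     gm = symbolic_trace(f)
#     with _use_lazy_graph_module(True):
#         lazy_gm = _make_graph_module(gm, gm.graph)  # constructs a _LazyGraphModule
#     eager_gm = _make_graph_module(gm, gm.graph)     # constructs a plain GraphModule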


@compatibility(is_backward_compatible=False)
class _LazyGraphModule(GraphModule):
    """
    The main difference between _LazyGraphModule and GraphModule is how recompilation happens.
    GraphModule does a 'recompile' call to generate the python code and the forward method when it's
    constructed. Later on, if the graph gets updated, the recompile method can be called again to refresh
    the saved python code and forward method.

    However, in some cases, especially in inductor, the recompilation can be a waste since we never
    check the python code for the graph module or call its forward method. A few concrete
    examples regarding pattern-matching fx passes in inductor:
    1. some passes update the graph to be compiled and then call recompile on the GraphModule.
    2. some passes trace small pattern functions to search for them in the graph being compiled and
       replace the matches with the traced graph of a replacement function. The pattern graph and
       replacement graph are quite small, but there are a large number of them. Doing GraphModule.recompile
       for each of them in GraphModule.__init__ is also a waste of time.

    However, simply skipping the GraphModule.recompile call in these scenarios is also dangerous.
    People may want to check the python code or call the GraphModule's forward method for debugging purposes.

    The way _LazyGraphModule solves this is by overriding the recompile method to just mark the
    need for recompilation without doing the actual recompilation. Later on, if people really
    access the compiled python code or call the GraphModule's forward method, we do the real
    recompilation.
    """

    @classmethod
    def from_graphmodule(cls, gm: GraphModule):
        if isinstance(gm, _LazyGraphModule):
            return gm
        else:
            return _LazyGraphModule(gm, gm.graph)

    @staticmethod
    def force_recompile(gm):
        """
        Sometimes we need to force a recompile as a workaround:
        - we want to do the real recompilation before symbolic_trace to avoid the error:
          https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
        """
        if isinstance(gm, _LazyGraphModule):
            gm.real_recompile()
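
    # Illustrative sketch (not part of the original class): forcing the real
    # recompile before re-tracing a module that may or may not be lazy:
    #
    #     _LazyGraphModule.force_recompile(gm)  # no-op for a plain GraphModule
    #     new_gm = torch.fx.symbolic_trace(gm)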

    def real_recompile(self):
        if self._needs_recompile():
            self._real_recompile()

    @classmethod
    def _needs_recompile(cls):
        return cls.forward is cls._lazy_forward

    def _lazy_forward(self, *args, **kwargs):
        # Call self.real_recompile() rather than self._real_recompile() here.
        # The _lazy_forward method may be saved and called repeatedly.
        # Calling self.real_recompile makes sure we skip the recompilation if
        # we have already done it.
        self.real_recompile()
        assert not self._needs_recompile()

        # Call `__call__` rather than 'forward' since recompilation may
        # install a wrapper for `__call__` to provide a customized error
        # message.
        return self(*args, **kwargs)

    forward = _lazy_forward

    # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__
    # or __reduce__, by calling _real_recompile. But I haven't found a good way
    # to test __reduce_deploy__. Also, it's very unlikely that _LazyGraphModule
    # will be used in torch::deploy, so it's skipped for now.
    def __reduce_package__(self, exporter: PackageExporter):
        """
        Follow GraphModule.__reduce_package__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need for recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need for recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()

        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))
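
    # Illustrative sketch (not part of the original class; `gm` is an assumed
    # existing torch.fx.GraphModule). Because __reduce__ forces the real
    # recompile, a _LazyGraphModule can be pickled even while a recompile is
    # still pending:
    #
    #     import pickle
    #
    #     lazy_gm = _LazyGraphModule.from_graphmodule(gm)
    #     payload = pickle.dumps(lazy_gm)   # triggers the real recompile via __reduce__
    #     restored = pickle.loads(payload)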

    def _real_recompile(self):
        return super().recompile()

    @classmethod
    def recompile(cls):
        # Override GraphModule.recompile: only mark that a recompile is needed by
        # resetting forward to the lazy stub. The real recompilation happens in
        # _real_recompile once the code or forward method is actually needed.
        cls.forward = cls._lazy_forward

    @property
    def code(self) -> str:
        self.real_recompile()
        return super().code

    def __str__(self) -> str:
        """
        str(GraphModule) accesses the _code attribute. Make sure a recompile
        happens first so the _code attribute is available.
        """
        self.real_recompile()
        return super().__str__()