common.py

# mypy: ignore-errors

import contextlib
import functools
import logging
from unittest.mock import patch

import torch

from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


class AotAutograd:
    """TorchDynamo backend that lowers an FX graph through AOTAutograd,
    forwarding ``kwargs`` (fw_compiler, bw_compiler, decompositions, ...)
    to ``aot_module_simplified``."""

    def __init__(self, **kwargs):
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        # NB: don't delete this counter increment
        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK, attempt to compile

        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return disable(disable(bw_compiler)(*args, **kwargs))

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop
        from torch._inductor.debug import enable_aot_logging

        # Debug asserts slow down compile time noticeably,
        # so only default them on when the aot_eager backend is used.
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()

        try:
            # NB: NOT cloned!
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
                return disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise


def aot_autograd(**kwargs):
    return AotAutograd(**kwargs)
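
# --- Illustrative usage (a sketch, not part of the original module) ---
# aot_autograd() builds a TorchDynamo backend: a callable taking
# (gm, example_inputs) and returning a compiled callable. Pairing it with the
# no-op compiler from functorch roughly mirrors the built-in "aot_eager"
# backend; `fn` below is a hypothetical function to compile:
#
#   from functorch.compile import nop
#
#   my_backend = aot_autograd(fw_compiler=nop)       # AotAutograd instance
#   opt_fn = torch.compile(fn, backend=my_backend)   # pass it as the backend
#   opt_fn(torch.randn(8, 8))                        # first call triggers compilation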


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs
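
# --- Illustrative usage (a sketch, not part of the original module) ---
# As the comment above notes, these kwargs are taken from functorch's
# memory_efficient_fusion(); splatting them into aot_autograd() yields a
# TorchScript-based backend. `fn` below is a hypothetical function to compile:
#
#   ts_backend = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
#   opt_fn = torch.compile(fn, backend=ts_backend)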


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs.  We swap out fake
    tensors for zero tensors.
    """

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
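
# --- Illustrative usage (a sketch, not part of the original module) ---
# A backend that must execute the graph (and therefore needs concrete tensors
# rather than Dynamo's FakeTensors) can be decorated so its example inputs are
# replaced with real zero-filled tensors before it runs. `my_backend` is
# hypothetical:
#
#   @fake_tensor_unsupported
#   def my_backend(gm: torch.fx.GraphModule, example_inputs):
#       gm(*example_inputs)   # safe: inputs are real zero tensors here
#       return gm.forward     # fall back to running the graph eagerly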


def device_from_inputs(example_inputs) -> torch.device:
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
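
# --- Illustrative usage (a sketch, not part of the original module) ---
# The two helpers above return the device/dtype of the first example input that
# carries one, which is convenient inside a custom backend. `my_backend` is
# hypothetical:
#
#   def my_backend(gm: torch.fx.GraphModule, example_inputs):
#       device = device_from_inputs(example_inputs)   # e.g. torch.device("cuda:0")
#       dtype = dtype_from_inputs(example_inputs)     # e.g. torch.float32
#       log.debug("compiling for %s / %s", device, dtype)
#       return gm.forward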