tools.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import warnings
  4. from typing import Any, Dict, Iterable, Optional, Tuple
  5. import torch
  6. import torch.export
  7. import torch.export._trace
  8. from torch._utils_internal import log_export_usage
  9. log = logging.getLogger(__name__)
  10. __all__ = ["report_exportability"]
  11. def _generate_inputs_for_submodules(
  12. model: torch.nn.Module,
  13. target_submodules: Iterable[str],
  14. args: Tuple[Any, ...],
  15. kwargs: Optional[Dict[str, Any]] = None,
  16. ) -> Dict[str, Tuple[Any, Any]]:
  17. """
  18. Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
  19. function doesn't work.
  20. Args:
  21. model: root model.
  22. inputs: inputs to the root model.
  23. target_submodules: submodules that we want to generate inputs for.
  24. Returns:
  25. A dict that maps from submodule name to its inputs.
  26. """
  27. kwargs = kwargs or {}
  28. handles = []
  29. results = {}
  30. submodule_to_names = {mod: name for name, mod in model.named_modules()}
  31. def pre_forward(module, module_args, module_kwargs):
  32. results[submodule_to_names[module]] = (module_args, module_kwargs)
  33. try:
  34. for name, mod in model.named_modules():
  35. if name in target_submodules:
  36. handles.append(
  37. mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
  38. )
  39. model(*args, **kwargs)
  40. except Exception as e:
  41. warnings.warn(
  42. f"Failed to generate submodule inputs because of the following error:\n{e}"
  43. )
  44. finally:
  45. for h in handles:
  46. h.remove()
  47. return results
  48. def report_exportability(
  49. mod: torch.nn.Module,
  50. args: Tuple[Any, ...],
  51. kwargs: Optional[Dict[str, Any]] = None,
  52. *,
  53. strict: bool = True,
  54. pre_dispatch: bool = False,
  55. ) -> Dict[str, Optional[Exception]]:
  56. """
  57. Report exportability issues for a module in one-shot.
  58. Args:
  59. mod: root module.
  60. args: args to the root module.
  61. kwargs: kwargs to the root module.
  62. Returns:
  63. A dict that maps from submodule name to the exception that was raised when trying to export it.
  64. `None` means the module is exportable without issue.
  65. Sample output:
  66. {
  67. '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
  68. 'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
  69. 'submod_2': None
  70. }
  71. """
  72. log_export_usage(event="export.report_exportability")
  73. kwargs = kwargs or {}
  74. all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
  75. submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
  76. report: Dict[str, Optional[Exception]] = {}
  77. def try_export(module, module_name, args, kwargs):
  78. nonlocal submod_inputs, report, strict, pre_dispatch
  79. if args is not None or kwargs is not None:
  80. try:
  81. torch.export._trace._export(
  82. module,
  83. args,
  84. kwargs,
  85. strict=strict,
  86. pre_dispatch=pre_dispatch,
  87. )
  88. report[module_name] = None
  89. log.info("Successfully exported `%s`", module_name)
  90. return
  91. except Exception as e:
  92. short_msg = repr(e).split("\n")[0]
  93. log.warning(
  94. "Failed exporting `%s` with exception: %s", module_name, short_msg
  95. )
  96. report[module_name] = e
  97. for name, submod in module.named_children():
  98. sub_module_name = name if module_name == "" else f"{module_name}.{name}"
  99. submod_args, submod_kwargs = submod_inputs.get(
  100. sub_module_name, (None, None)
  101. )
  102. try_export(submod, sub_module_name, submod_args, submod_kwargs)
  103. return
  104. try_export(mod, "", args, kwargs)
  105. unique_issues = set()
  106. for exception in report.values():
  107. if exception is not None:
  108. key = repr(exception).split("\\n")[0]
  109. unique_issues.add(key)
  110. log.warning("Found %d export issues:", len(unique_issues))
  111. for issue in unique_issues:
  112. log.warning(issue)
  113. return report