| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- # mypy: allow-untyped-defs
- import logging
- import warnings
- from typing import Any, Dict, Iterable, Optional, Tuple
- import torch
- import torch.export
- import torch.export._trace
- from torch._utils_internal import log_export_usage
# Only the exportability reporting entry point is part of this module's public API.
__all__ = ["report_exportability"]

# Module-level logger; per-submodule export outcomes are reported through it.
log = logging.getLogger(name=__name__)
- def _generate_inputs_for_submodules(
- model: torch.nn.Module,
- target_submodules: Iterable[str],
- args: Tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Tuple[Any, Any]]:
- """
- Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
- function doesn't work.
- Args:
- model: root model.
- inputs: inputs to the root model.
- target_submodules: submodules that we want to generate inputs for.
- Returns:
- A dict that maps from submodule name to its inputs.
- """
- kwargs = kwargs or {}
- handles = []
- results = {}
- submodule_to_names = {mod: name for name, mod in model.named_modules()}
- def pre_forward(module, module_args, module_kwargs):
- results[submodule_to_names[module]] = (module_args, module_kwargs)
- try:
- for name, mod in model.named_modules():
- if name in target_submodules:
- handles.append(
- mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
- )
- model(*args, **kwargs)
- except Exception as e:
- warnings.warn(
- f"Failed to generate submodule inputs because of the following error:\n{e}"
- )
- finally:
- for h in handles:
- h.remove()
- return results
def report_exportability(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    strict: bool = True,
    pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]:
    """
    Report exportability issues for a module in one-shot.

    The search works as follows:

    - The root module is exported first.
    - Whenever an export attempt fails, each direct child is tried next,
      using the inputs that child received during one eager run of the root
      module (captured by ``_generate_inputs_for_submodules``).
    - This recurses until a module exports cleanly or has no children.
    - A submodule with no captured inputs is not exported itself, but its
      children are still visited.

    Args:
        mod: root module.
        args: args to the root module.
        kwargs: kwargs to the root module.
        strict: forwarded to ``torch.export._trace._export``.
        pre_dispatch: forwarded to ``torch.export._trace._export``.

    Returns:
        A dict that maps from submodule name to the exception that was raised
        when trying to export it. `None` means the module is exportable without
        issue. Modules that were never attempted have no entry: either an
        ancestor exported successfully, or no inputs were captured for them.

        Sample output:
        {
            '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_2': None
        }
    """
    log_export_usage(event="export.report_exportability")

    kwargs = kwargs or {}

    all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
    submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

    report: Dict[str, Optional[Exception]] = {}

    def _one_line(exc: BaseException) -> str:
        # repr() renders newlines inside string arguments as the two-character
        # escape backslash + "n", while a custom __repr__ may emit real
        # newlines. Cut at whichever appears first so both the log summary and
        # the dedup key below are the same single line.
        return repr(exc).split("\\n", 1)[0].split("\n", 1)[0]

    def try_export(module, module_name, args, kwargs):
        # (None, None) marks a submodule whose inputs were never captured.
        if args is not None or kwargs is not None:
            try:
                torch.export._trace._export(
                    module,
                    args,
                    kwargs,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                )
            except Exception as e:
                log.warning(
                    "Failed exporting `%s` with exception: %s",
                    module_name,
                    _one_line(e),
                )
                report[module_name] = e
            else:
                report[module_name] = None
                log.info("Successfully exported `%s`", module_name)
                # This subtree is exportable; no need to inspect its children.
                return

        # Export failed or was skipped: drill into the direct children to
        # narrow down where the failure comes from.
        for name, submod in module.named_children():
            sub_module_name = name if module_name == "" else f"{module_name}.{name}"
            submod_args, submod_kwargs = submod_inputs.get(
                sub_module_name, (None, None)
            )
            try_export(submod, sub_module_name, submod_args, submod_kwargs)

    try_export(mod, "", args, kwargs)

    # Deduplicate failures by their one-line summary; sort for deterministic logs.
    unique_issues = {_one_line(e) for e in report.values() if e is not None}
    log.warning("Found %d export issues:", len(unique_issues))
    for issue in sorted(unique_issues):
        log.warning(issue)

    return report
|