| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- # mypy: allow-untyped-defs
- import inspect
- import re
- import string
- from dataclasses import dataclass, field
- from enum import Enum
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
- from types import ModuleType
- import torch
- _TAGS: Dict[str, Dict[str, Any]] = {
- "torch": {
- "cond": {},
- "dynamic-shape": {},
- "escape-hatch": {},
- "map": {},
- "dynamic-value": {},
- "operator": {},
- "mutation": {},
- },
- "python": {
- "assert": {},
- "builtin": {},
- "closure": {},
- "context-manager": {},
- "control-flow": {},
- "data-structure": {},
- "standard-library": {},
- "object-model": {},
- },
- }
class SupportLevel(Enum):
    """How far export support has progressed for an example's feature.

    Describes the stage at which the feature exercised by an example case is
    handled by export.
    """

    SUPPORTED = 1  # export handles the feature
    NOT_SUPPORTED_YET = 0  # export does not handle the feature yet
class ExportArgs:
    """Positional and keyword example inputs bundled together.

    A bare tuple only carries positional inputs; construct
    ``ExportArgs(*args, **kwargs)`` when keyword inputs are needed as well.
    """

    __slots__ = ("args", "kwargs")

    def __init__(self, *args, **kwargs):
        # Stored as received: ``args`` is a tuple, ``kwargs`` a dict.
        self.args, self.kwargs = args, kwargs


# Example inputs are either a plain tuple of positional arguments or an
# ExportArgs holding positional and keyword arguments.
InputsType = Union[Tuple[Any, ...], ExportArgs]
def check_inputs_type(x):
    """Validate that ``x`` is a supported example-inputs container.

    Accepted containers are a plain ``tuple`` or an ``ExportArgs`` instance.

    Raises:
        ValueError: if ``x`` is of any other type.
    """
    if isinstance(x, (ExportArgs, tuple)):
        return
    raise ValueError(
        f"Expecting inputs type to be either a tuple, or ExportArgs, got: {type(x)}"
    )
def _validate_tag(tag: str) -> None:
    """Check that ``tag`` is a path registered in ``_TAGS``.

    ``tag`` is dot-separated (e.g. ``"torch.dynamic-shape"``). Every component
    must consist only of lowercase ASCII letters and ``-`` and must be a key at
    the corresponding nesting level of ``_TAGS``.

    Raises:
        ValueError: if a component contains invalid characters, or the path
            is not found in ``_TAGS``.
    """
    allowed_chars = set(string.ascii_lowercase + "-")
    node: Dict[str, Any] = _TAGS
    for part in tag.split("."):
        # Explicit raise instead of ``assert`` so this validation is not
        # stripped when Python runs with ``-O``.
        if not set(part) <= allowed_chars:
            raise ValueError(f"Tag contains invalid characters: {part}")
        if part not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        node = node[part]
@dataclass(frozen=True)
class ExportCase:
    """One immutable example case in the export example bank.

    ``__post_init__`` validates the inputs, the tags and the description as
    soon as an instance is constructed.
    """

    example_inputs: InputsType
    description: str  # Human-readable description of the use case.
    model: torch.nn.Module
    name: str
    # Optional extra inputs, used for testing graph generalization.
    extra_inputs: Optional[InputsType] = None
    # Tags for the use case, e.g. "torch.dynamic-shape" or "python.closure";
    # each one is checked by _validate_tag.
    tags: Set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Inputs are validated first, then tags, then the description.
        inputs_to_check = [self.example_inputs]
        if self.extra_inputs is not None:
            inputs_to_check.append(self.extra_inputs)
        for inputs in inputs_to_check:
            check_inputs_type(inputs)

        for tag in self.tags:
            _validate_tag(tag)

        has_description = isinstance(self.description, str) and len(self.description) > 0
        if not has_description:
            raise ValueError(f'Invalid description: "{self.description}"')
# Module-level registries, filled in by the registration helpers below.
# Case name -> the first ExportCase registered under that name.
_EXAMPLE_CASES: Dict[str, ExportCase] = dict()
# Modules that have already used @export_case.
_MODULES: Set[ModuleType] = set()
# Duplicated case name -> every ExportCase registered under it, first one included.
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = dict()
# Parent case name -> rewrite cases registered with export_rewrite_case.
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = dict()
def register_db_case(case: ExportCase) -> None:
    """Add a user-provided ``case`` to the example bank.

    The first case registered under a name is the one kept in
    ``_EXAMPLE_CASES``. A later case with the same name does not replace it.
    Instead, ``_EXAMPLE_CONFLICT_CASES[name]`` collects the original case
    followed by each conflicting case, in registration order.
    """
    name = case.name
    if name not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[name] = case
        return
    conflicts = _EXAMPLE_CONFLICT_CASES.setdefault(name, [_EXAMPLE_CASES[name]])
    conflicts.append(case)
def to_snake_case(name):
    """Convert a CamelCase identifier to snake_case.

    For example ``"CondBranchClassMethod"`` becomes
    ``"cond_branch_class_method"`` and ``"HTTPServer"`` becomes
    ``"http_server"``.
    """
    # Pass 1: put "_" before each capitalised word that follows any character
    # ("HTTPServer" -> "HTTP_Server").
    partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: put "_" where a lowercase letter or digit meets a capital
    # ("getHTTP" -> "get_HTTP"), then lowercase the result.
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partially_split)
    return fully_split.lower()
- def _make_export_case(m, name, configs):
- if not issubclass(m, torch.nn.Module):
- raise TypeError("Export case class should be a torch.nn.Module.")
- m = m()
- if "description" not in configs:
- # Fallback to docstring if description is missing.
- assert (
- m.__doc__ is not None
- ), f"Could not find description or docstring for export case: {m}"
- configs = {**configs, "description": m.__doc__}
- return ExportCase(**{**configs, "model": m, "name": name})
def export_case(**kwargs):
    """Decorator for registering a user-provided case into the example bank.

    ``kwargs`` are forwarded to ``ExportCase`` via ``_make_export_case``. The
    decorated class must be the only export case in its module, and the last
    component of the module name must equal the snake_case form of the class
    name. The decorator returns the created ``ExportCase``, not the class.

    Raises:
        RuntimeError: if the defining module cannot be determined, already
            registered a case, or is named inconsistently with the class.
    """

    def wrapper(m):
        module = inspect.getmodule(m)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if module is None:
            raise RuntimeError(f"Could not determine the module that defines {m}.")
        if module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")
        normalized_name = to_snake_case(m.__name__)
        module_name = module.__name__.split(".")[-1]
        if module_name != normalized_name:
            raise RuntimeError(
                f'Module name "{module.__name__}" is inconsistent with exported program '
                + f'name "{m.__name__}". Please rename the module to "{normalized_name}".'
            )
        case = _make_export_case(m, module_name, kwargs)
        # Mark the module as used only after its case was built successfully, so
        # a failed registration does not leave stale state behind.
        _MODULES.add(module)
        register_db_case(case)
        return case

    return wrapper
def export_rewrite_case(**kwargs):
    """Decorator for registering a rewrite of an existing example case.

    ``kwargs`` must contain ``parent``, the ``ExportCase`` being rewritten. The
    remaining ``kwargs`` are forwarded to ``ExportCase``, with
    ``example_inputs`` always taken from the parent. The new case is named
    after the snake_case form of the decorated class's name and appended to
    ``_EXAMPLE_REWRITE_CASES[parent.name]``. The decorator returns the created
    ``ExportCase``, not the class.

    Raises:
        KeyError: if ``parent`` is not supplied.
        TypeError: if ``parent`` is not an ``ExportCase``.
    """

    def wrapper(m):
        # Work on a copy: popping from / writing into ``kwargs`` directly would
        # mutate the decorator's captured state, so applying it a second time
        # would fail with KeyError('parent').
        configs = dict(kwargs)
        parent = configs.pop("parent")
        if not isinstance(parent, ExportCase):
            raise TypeError(f"parent must be an ExportCase, got: {type(parent)}")
        configs["example_inputs"] = parent.example_inputs
        case = _make_export_case(m, to_snake_case(m.__name__), configs)
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper
def normalize_inputs(x: InputsType) -> ExportArgs:
    """Return ``x`` as an ``ExportArgs``.

    A tuple is unpacked into the positional arguments of a new ``ExportArgs``;
    an ``ExportArgs`` is returned as-is (the same object).
    """
    if not isinstance(x, tuple):
        assert isinstance(x, ExportArgs)
        return x
    return ExportArgs(*x)
|