| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- # NOTE: This is a placeholder for iterating on export serialization schema design.
- # Anything is subject to change and no guarantee is provided at this point.
- from dataclasses import dataclass, field
- from enum import IntEnum
- from typing import Dict, List, Optional, Tuple
- from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema.
# (major, minor) pair; presumably kept in sync with the SchemaVersion dataclass
# defined later in this module (major = breaking change, minor = compatible
# change) — confirm against the serializer.
SCHEMA_VERSION = (5, 3)
# Version tag for the pytree TreeSpec string format; presumably the encoding
# used by ModuleCallSignature.in_spec / out_spec — TODO confirm.
TREESPEC_VERSION = 1
# Element dtype of a tensor (used by TensorMeta.dtype and Argument.as_scalar_type).
# This is an IntEnum inside the serialization schema, so the integer values
# below are what gets persisted: never renumber or reuse a value; append new
# members with fresh values instead. Member names appear to follow
# c10::ScalarType naming (e.g. BYTE/CHAR presumably uint8/int8) — confirm.
class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
# Tensor memory layout (used by TensorMeta.layout and Argument.as_layout).
# Integer values are part of the serialized schema; do not renumber.
class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    # Leading underscore presumably mirrors torch's private `_mkldnn` layout name.
    _mkldnn = 6
    Strided = 7
# Tensor memory format (used by Argument.as_memory_format).
# Integer values are part of the serialized schema; do not renumber.
class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4
# A device: a type string (presumably a torch device type such as "cpu" or
# "cuda") plus an optional index.
@dataclass
class Device:
    type: str
    # None (the default) when no explicit device index is recorded.
    index: Optional[int] = None
# Concrete hint value for a SymExpr, as an int, float, or bool.
# Tagged union (subclass of _Union): one as_* field is set per instance —
# presumably enforced by _Union (torch._export.serde.union), not visible here.
# repr=False disables the dataclass-generated __repr__, presumably so the one
# provided by _Union is used.
@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool
# This is for storing the symbolic expressions behind symints/symfloats/symbools.
# For example, we can get something like
#   SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    # Textual form of the symbolic expression, e.g. "s0 + s1".
    expr_str: str
    # Optional concrete value of the expression; None when no hint is known.
    hint: Optional[SymExprHint] = None
# A possibly-symbolic int: either a symbolic expression (as_expr) or a concrete
# int (as_int). Used for tensor sizes/strides/offsets and Graph.sym_int_values.
# Tagged union via _Union; repr=False defers to _Union's repr (presumably).
@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int
# A possibly-symbolic bool: either a symbolic expression (as_expr) or a concrete
# bool (as_bool). Used for Graph.sym_bool_values.
# Tagged union via _Union; repr=False defers to _Union's repr (presumably).
@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool
# Metadata describing a tensor value (not its data); Graph.tensor_values maps
# tensor names to these. Field order is the dataclass constructor order and is
# part of the schema, so do not reorder even though it is not grouped logically.
@dataclass
class TensorMeta:
    dtype: ScalarType
    # Sizes/strides/offset are SymInt because they may be symbolic.
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout
# In most cases we will use the "as_name" field to store arguments which are
# SymInts (the name presumably keys an entry in Graph.sym_int_values).
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInts and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: str
    as_int: int
# In most cases we will use the "as_name" field to store arguments which are
# SymBools (the name presumably keys an entry in Graph.sym_bool_values).
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBools and bools (ex. [True, i0, ...]). We will serialize this type of list
# to be List[SymBoolArgument] and map the SymBools to the "as_name" field, and
# bools to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: str
    as_bool: bool
# Reference to a tensor value by name; the name presumably keys
# Graph.tensor_values.
@dataclass
class TensorArgument:
    name: str
# Reference to a token value by name; used by InputTokenSpec and
# OutputTokenSpec. Token semantics are not defined in this file (presumably
# side-effect ordering tokens — confirm).
@dataclass
class TokenArgument:
    name: str
# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: TensorArgument
    # Empty tuple: a pure marker carrying no data.
    as_none: Tuple[()]
# A named subgraph passed as an argument (presumably to higher-order ops such
# as cond/map, cf. Graph.is_single_tensor_return).
@dataclass
class GraphArgument:
    name: str
    # String forward reference: Graph is defined later in this module.
    graph: 'Graph'
# Reference to a custom object by name; class_fqn is presumably the object's
# fully-qualified class name. Graph.custom_obj_values maps names to these.
@dataclass
class CustomObjArgument:
    name: str
    class_fqn: str
# This is actually a union type: exactly one as_* field describes the value of
# a node input/output or graph input/output (exclusivity presumably enforced by
# _Union). Field names are part of the serialized schema; do not rename.
@dataclass(repr=False)
class Argument(_Union):
    # Empty-tuple marker for None.
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    # Presumably an operator's qualified name stored as a string — confirm.
    as_operator: str
# A node input paired with the parameter name it binds to.
@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument
# One operation in a Graph.
@dataclass
class Node:
    # Call target as a string (presumably the operator's qualified name).
    target: str
    # Inputs are keyed by operator-schema parameter name (see NamedArgument).
    inputs: List[NamedArgument]
    outputs: List[Argument]
    # Free-form string-to-string metadata.
    metadata: Dict[str, str]
# A serialized graph: its inputs/outputs, nodes, and name-keyed tables of value
# metadata (tensors, symbolic ints/bools, custom objects).
@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    # Keyed by value name (cf. TensorArgument.name / SymIntArgument.as_name).
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    # default_factory gives each Graph its own dict (no shared mutable default).
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
# Input spec for a graph input supplied by the user.
@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument
# A constant input value: None (empty-tuple marker), int, float, string, or bool.
# Tagged union via _Union; repr=False defers to _Union's repr (presumably).
@dataclass(repr=False)
class ConstantValue(_Union):
    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool
# Input spec for a graph input bound to a constant value.
@dataclass
class ConstantInputSpec:
    name: str
    value: ConstantValue
# Input spec: graph input tensor `arg` is the parameter named `parameter_name`.
@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str
# Input spec: graph input tensor `arg` is the buffer named `buffer_name`.
@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    # Presumably mirrors nn.Module buffer persistence (state_dict inclusion) — confirm.
    persistent: bool
# Input spec: graph input tensor `arg` is the tensor constant named
# `tensor_constant_name`.
@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str
# Input spec: graph input `arg` is the custom object named `custom_obj_name`.
@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str
# Input spec: graph input `arg` is a token (see TokenArgument).
@dataclass
class InputTokenSpec:
    arg: TokenArgument
# Classifies one graph input; exactly one field is set (tagged union via
# _Union). Field names are part of the serialized schema; do not rename.
@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec
# Output spec for a graph output returned to the user.
@dataclass
class UserOutputSpec:
    arg: Argument
# Output spec marking the graph output tensor that holds the loss.
@dataclass
class LossOutputSpec:
    arg: TensorArgument
# Output spec: graph output tensor `arg` is (presumably) the new value written
# to the buffer named `buffer_name`.
@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str
# Output spec: graph output tensor `arg` is the gradient for the parameter
# named `parameter_name`.
@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str
# Output spec: graph output tensor `arg` is the gradient for the user input
# named `user_input_name`.
@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str
# Output spec: graph output tensor `arg` is (presumably) the new value written
# to the user input named `user_input_name`.
@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str
# Output spec: graph output `arg` is a token (see TokenArgument).
@dataclass
class OutputTokenSpec:
    arg: TokenArgument
# Classifies one graph output; exactly one field is set (tagged union via
# _Union). Field names are part of the serialized schema; do not rename.
@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec
# Classification of every graph input and output. Presumably input_specs /
# output_specs line up positionally with Graph.inputs / Graph.outputs — confirm.
@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]
# Value range for a symbol (see ExportedProgram.range_constraints). Whether the
# bounds are inclusive is not stated here — TODO confirm with the serializer.
@dataclass
class RangeConstraint:
    min_val: int
    max_val: int
# Calling convention of one (sub)module, used for unflattening.
@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]
    # These are serialized by calling pytree.treespec_dumps
    # and deserialized by calling pytree.treespec_loads.
    in_spec: str
    out_spec: str
# One entry of GraphModule.module_call_graph: a module identified by `fqn`
# (presumably its fully-qualified module path) and its call signature, which
# may be absent (None by default).
@dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None
# A graph together with its input/output signature and module call structure.
@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]
# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.
# Top-level serialized object: the graph module plus versioning information and
# symbol range constraints.
@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    # Keyed by symbol name (presumably the symbols used in SymExpr.expr_str).
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    # Free-form dialect identifier string; its allowed values are not defined here.
    dialect: str
|