# schema.py
# NOTE: This is a placeholder for iterating on export serialization schema design.
# Anything is subject to change and no guarantee is provided at this point.
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

from torch._export.serde.union import _Union
  7. # NOTE: Please update this value if any modifications are made to the schema
  8. SCHEMA_VERSION = (5, 3)
  9. TREESPEC_VERSION = 1
  10. class ScalarType(IntEnum):
  11. UNKNOWN = 0
  12. BYTE = 1
  13. CHAR = 2
  14. SHORT = 3
  15. INT = 4
  16. LONG = 5
  17. HALF = 6
  18. FLOAT = 7
  19. DOUBLE = 8
  20. COMPLEXHALF = 9
  21. COMPLEXFLOAT = 10
  22. COMPLEXDOUBLE = 11
  23. BOOL = 12
  24. BFLOAT16 = 13
  25. class Layout(IntEnum):
  26. Unknown = 0
  27. SparseCoo = 1
  28. SparseCsr = 2
  29. SparseCsc = 3
  30. SparseBsr = 4
  31. SparseBsc = 5
  32. _mkldnn = 6
  33. Strided = 7
  34. class MemoryFormat(IntEnum):
  35. Unknown = 0
  36. ContiguousFormat = 1
  37. ChannelsLast = 2
  38. ChannelsLast3d = 3
  39. PreserveFormat = 4
  40. @dataclass
  41. class Device:
  42. type: str
  43. index: Optional[int] = None
  44. @dataclass(repr=False)
  45. class SymExprHint(_Union):
  46. as_int: int
  47. as_float: float
  48. as_bool: bool
  49. # This is for storing the symbolic expressions behind symints/symfloats/symbools
  50. # For example, we can get something like
  51. # SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4)
  52. # if we also have the hint that s0 and s1 are both 2.
  53. @dataclass
  54. class SymExpr:
  55. expr_str: str
  56. hint: Optional[SymExprHint] = None
  57. @dataclass(repr=False)
  58. class SymInt(_Union):
  59. as_expr: SymExpr
  60. as_int: int
  61. @dataclass(repr=False)
  62. class SymBool(_Union):
  63. as_expr: SymExpr
  64. as_bool: bool
  65. @dataclass
  66. class TensorMeta:
  67. dtype: ScalarType
  68. sizes: List[SymInt]
  69. requires_grad: bool
  70. device: Device
  71. strides: List[SymInt]
  72. storage_offset: SymInt
  73. layout: Layout
  74. # In most cases we will use the "as_name" field to store arguments which are
  75. # SymInts.
  76. # The "as_int" field is used in the case where we have a list containing a mix
  77. # of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
  78. # be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
  79. # to the "as_int" field.
  80. @dataclass(repr=False)
  81. class SymIntArgument(_Union):
  82. as_name: str
  83. as_int: int
  84. # In most cases we will use the "as_name" field to store arguments which are
  85. # SymBools.
  86. # The "as_bool" field is used in the case where we have a list containing a mix
  87. # of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
  88. # be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools
  89. # to the "as_bool" field.
  90. @dataclass(repr=False)
  91. class SymBoolArgument(_Union):
  92. as_name: str
  93. as_bool: bool
  94. @dataclass
  95. class TensorArgument:
  96. name: str
  97. @dataclass
  98. class TokenArgument:
  99. name: str
  100. # This is use for storing the contents of a list which contain optional tensors
  101. # (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
  102. # type List[OptionalTensorArgument], with tensor values seiralized to the
  103. # "as_tensor" field, and None values serialized to the "as_none" field.
  104. @dataclass(repr=False)
  105. class OptionalTensorArgument(_Union):
  106. as_tensor: TensorArgument
  107. as_none: Tuple[()]
  108. @dataclass
  109. class GraphArgument:
  110. name: str
  111. graph: 'Graph'
  112. @dataclass
  113. class CustomObjArgument:
  114. name: str
  115. class_fqn: str
  116. # This is actually a union type
  117. @dataclass(repr=False)
  118. class Argument(_Union):
  119. as_none: Tuple[()]
  120. as_tensor: TensorArgument
  121. as_tensors: List[TensorArgument]
  122. as_int: int
  123. as_ints: List[int]
  124. as_float: float
  125. as_floats: List[float]
  126. as_string: str
  127. as_strings: List[str]
  128. as_sym_int: SymIntArgument
  129. as_sym_ints: List[SymIntArgument]
  130. as_scalar_type: ScalarType
  131. as_memory_format: MemoryFormat
  132. as_layout: Layout
  133. as_device: Device
  134. as_bool: bool
  135. as_bools: List[bool]
  136. as_sym_bool: SymBoolArgument
  137. as_sym_bools: List[SymBoolArgument]
  138. as_graph: GraphArgument
  139. as_optional_tensors: List[OptionalTensorArgument]
  140. as_custom_obj: CustomObjArgument
  141. as_operator: str
  142. @dataclass
  143. class NamedArgument:
  144. # Argument name from the operator schema
  145. name: str
  146. arg: Argument
  147. @dataclass
  148. class Node:
  149. target: str
  150. inputs: List[NamedArgument]
  151. outputs: List[Argument]
  152. metadata: Dict[str, str]
  153. @dataclass
  154. class Graph:
  155. inputs: List[Argument]
  156. outputs: List[Argument]
  157. nodes: List[Node]
  158. tensor_values: Dict[str, TensorMeta]
  159. sym_int_values: Dict[str, SymInt]
  160. sym_bool_values: Dict[str, SymBool]
  161. # This is for deserializing the submodule graphs from higher order ops
  162. # (ex. cond, map) where single tensor returns will just return a single
  163. # tensor, rather than following export schema and returning a singleton
  164. # list.
  165. is_single_tensor_return: bool = False
  166. custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
  167. @dataclass
  168. class UserInputSpec:
  169. # Actually, only tensors and SymInts are allowed here
  170. arg: Argument
  171. @dataclass(repr=False)
  172. class ConstantValue(_Union):
  173. as_none: Tuple[()]
  174. as_int: int
  175. as_float: float
  176. as_string: str
  177. as_bool: bool
  178. @dataclass
  179. class ConstantInputSpec:
  180. name: str
  181. value: ConstantValue
  182. @dataclass
  183. class InputToParameterSpec:
  184. arg: TensorArgument
  185. parameter_name: str
  186. @dataclass
  187. class InputToBufferSpec:
  188. arg: TensorArgument
  189. buffer_name: str
  190. persistent: bool
  191. @dataclass
  192. class InputToTensorConstantSpec:
  193. arg: TensorArgument
  194. tensor_constant_name: str
  195. @dataclass
  196. class InputToCustomObjSpec:
  197. arg: CustomObjArgument
  198. custom_obj_name: str
  199. @dataclass
  200. class InputTokenSpec:
  201. arg: TokenArgument
  202. @dataclass(repr=False)
  203. class InputSpec(_Union):
  204. user_input: UserInputSpec
  205. parameter: InputToParameterSpec
  206. buffer: InputToBufferSpec
  207. tensor_constant: InputToTensorConstantSpec
  208. custom_obj: InputToCustomObjSpec
  209. token: InputTokenSpec
  210. constant_input: ConstantInputSpec
  211. @dataclass
  212. class UserOutputSpec:
  213. arg: Argument
  214. @dataclass
  215. class LossOutputSpec:
  216. arg: TensorArgument
  217. @dataclass
  218. class BufferMutationSpec:
  219. arg: TensorArgument
  220. buffer_name: str
  221. @dataclass
  222. class GradientToParameterSpec:
  223. arg: TensorArgument
  224. parameter_name: str
  225. @dataclass
  226. class GradientToUserInputSpec:
  227. arg: TensorArgument
  228. user_input_name: str
  229. @dataclass
  230. class UserInputMutationSpec:
  231. arg: TensorArgument
  232. user_input_name: str
  233. @dataclass
  234. class OutputTokenSpec:
  235. arg: TokenArgument
  236. @dataclass(repr=False)
  237. class OutputSpec(_Union):
  238. user_output: UserOutputSpec
  239. loss_output: LossOutputSpec
  240. buffer_mutation: BufferMutationSpec
  241. gradient_to_parameter: GradientToParameterSpec
  242. gradient_to_user_input: GradientToUserInputSpec
  243. user_input_mutation: UserInputMutationSpec
  244. token: OutputTokenSpec
  245. @dataclass
  246. class GraphSignature:
  247. input_specs: List[InputSpec]
  248. output_specs: List[OutputSpec]
  249. @dataclass
  250. class RangeConstraint:
  251. min_val: int
  252. max_val: int
  253. @dataclass
  254. class ModuleCallSignature:
  255. inputs: List[Argument]
  256. outputs: List[Argument]
  257. # These are serialized by calling pytree.treespec_loads
  258. # And deserialized by calling pytree.treespec_dumps
  259. in_spec: str
  260. out_spec: str
  261. @dataclass
  262. class ModuleCallEntry:
  263. fqn: str
  264. signature: Optional[ModuleCallSignature] = None
  265. @dataclass
  266. class GraphModule:
  267. graph: Graph
  268. signature: GraphSignature
  269. # This is used for unflattening, by tracking the calling structure of all of
  270. # the modules in order to unflatten the modules back to the eager calling
  271. # conventions.
  272. module_call_graph: List[ModuleCallEntry]
  273. # Invariant: Every time a change is made to the schema, one of the versions
  274. # should be upadted.
  275. @dataclass
  276. class SchemaVersion:
  277. major: int # Major version number is bumped every time a breaking change is made.
  278. minor: int # Minor version number is bumped when a compatible change is made.
  279. @dataclass
  280. class ExportedProgram:
  281. graph_module: GraphModule
  282. # Key is the opset namespace (ex. aten), and value is the version number
  283. opset_version: Dict[str, int]
  284. range_constraints: Dict[str, RangeConstraint]
  285. schema_version: SchemaVersion
  286. dialect: str