# mypy: allow-untyped-defs
import base64
import copy
import copyreg
import dataclasses
import heapq
import inspect
import io
import json
import logging
import math
import operator
import re
import typing

from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    Type,
)

import sympy

import torch
import torch.export.exported_program as ep
from torch._export.serde.schema import SchemaVersion
from torch._export.verifier import load_verifier
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.value_ranges import ValueRanges

from .schema import (  # type: ignore[attr-defined]
    Argument,
    BufferMutationSpec,
    ConstantInputSpec,
    ConstantValue,
    CustomObjArgument,
    Device,
    ExportedProgram,
    GradientToParameterSpec,
    GradientToUserInputSpec,
    Graph,
    GraphArgument,
    GraphModule,
    GraphSignature,
    InputSpec,
    InputToBufferSpec,
    InputToCustomObjSpec,
    InputTokenSpec,
    InputToParameterSpec,
    InputToTensorConstantSpec,
    Layout,
    LossOutputSpec,
    MemoryFormat,
    ModuleCallEntry,
    ModuleCallSignature,
    NamedArgument,
    Node,
    OptionalTensorArgument,
    OutputSpec,
    OutputTokenSpec,
    RangeConstraint,
    ScalarType,
    SCHEMA_VERSION,
    SymBool,
    SymBoolArgument,
    SymExpr,
    SymExprHint,
    SymInt,
    SymIntArgument,
    TensorArgument,
    TensorMeta,
    TokenArgument,
    TREESPEC_VERSION,
    UserInputMutationSpec,
    UserInputSpec,
    UserOutputSpec,
)
from .union import _Union


__all__ = [
    "serialize",
    "GraphModuleSerializer",
    "ExportedProgramSerializer",
    "GraphModuleDeserializer",
    "ExportedProgramDeserializer",
]


log = logging.getLogger(__name__)


class SerializeError(RuntimeError):
    pass


def _reverse_map(d: Dict[Any, Enum]):
    return {v.value: k for k, v in d.items()}


MetaType = Union[
    FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument
]

ST_DELIMITER = ";"

_TORCH_TO_SERIALIZE_DTYPE = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.float16: ScalarType.HALF,
    torch.float32: ScalarType.FLOAT,
    torch.float64: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEXHALF,
    torch.complex64: ScalarType.COMPLEXFLOAT,
    torch.complex128: ScalarType.COMPLEXDOUBLE,
    torch.bool: ScalarType.BOOL,
    torch.bfloat16: ScalarType.BFLOAT16,
}

_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE)  # type: ignore[arg-type]

_TORCH_TO_SERIALIZE_LAYOUT = {
    torch.sparse_coo: Layout.SparseCoo,
    torch.sparse_csr: Layout.SparseCsr,
    torch.sparse_csc: Layout.SparseCsc,
    torch.sparse_bsr: Layout.SparseBsr,
    torch.sparse_bsc: Layout.SparseBsc,
    torch._mkldnn: Layout._mkldnn,  # type: ignore[attr-defined]
    torch.strided: Layout.Strided,
}

_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT)  # type: ignore[arg-type]

_TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
    torch.contiguous_format: MemoryFormat.ContiguousFormat,
    torch.channels_last: MemoryFormat.ChannelsLast,
    torch.channels_last_3d: MemoryFormat.ChannelsLast3d,
    torch.preserve_format: MemoryFormat.PreserveFormat,
}

_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT)  # type: ignore[arg-type]
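
# Note: the forward maps above translate torch values into schema enums at
# serialization time; the _reverse_map'd counterparts are keyed by the enum's
# *value* (see _reverse_map), so a lookup such as
# _SERIALIZE_TO_TORCH_DTYPE[ScalarType.FLOAT.value] -> torch.float32 is the
# direction intended for the deserialization side.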

_SYM_INT_OPS = {
    operator.mul,
    operator.add,
    operator.sub,
    operator.floordiv,
    operator.mod,
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_sqrt,
}

_SYM_BOOL_OPS = {
    operator.eq,
    operator.ne,
    operator.le,
    operator.ge,
    operator.lt,
    operator.gt,
    torch.sym_not,
}
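
# call_function nodes whose target is in _SYM_INT_OPS / _SYM_BOOL_OPS are treated
# as symbolic int/bool computations by GraphModuleSerializer.handle_call_function
# (serialized via serialize_sym_int_output / serialize_sym_bool_output) rather than
# as regular operator calls with tensor outputs.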


@dataclass
class SerializedArtifact:
    exported_program: bytes
    state_dict: bytes
    constants: bytes
    example_inputs: bytes


@dataclass
class _SerializedProgram:
    exported_program: ExportedProgram
    state_dict: bytes
    constants: bytes
    example_inputs: bytes


def deserialize_device(d: Device) -> torch.device:
    if d.index is None:
        return torch.device(type=d.type)  # type: ignore[call-overload]
    return torch.device(type=d.type, index=d.index)
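

# serialize_sym_int encodes a concrete value directly (as_int); a symbolic value is
# stored as its sympy expression string (as_expr), together with the traced hint
# value when one is available on the SymInt's node.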
def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
    if isinstance(s, (torch.SymInt, int)):
        if symbolic_shapes.is_concrete_int(s):
            return SymInt.create(as_int=int(s))
        else:
            assert isinstance(s, torch.SymInt)
            if s.node.hint is None:
                return SymInt.create(as_expr=SymExpr(str(s)))
            else:
                return SymInt.create(
                    as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint))
                )
    else:
        raise SerializeError(
            f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
        )


def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
    if isinstance(s, (torch.SymBool, bool)):
        if symbolic_shapes.is_concrete_bool(s):
            return SymBool.create(as_bool=bool(s))
        else:
            return SymBool.create(as_expr=SymExpr(expr_str=str(s)))
    else:
        raise SerializeError(
            f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`"
        )


def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
    """
    Extract a TensorMeta describing `t`.
    """
    return TensorMeta(
        dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype],
        sizes=[serialize_sym_int(s) for s in t.shape],
        requires_grad=t.requires_grad,
        device=Device(type=t.device.type, index=t.device.index),
        strides=[serialize_sym_int(s) for s in t.stride()],
        storage_offset=serialize_sym_int(0),  # TODO needs to be fixed.
        layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
    )
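

# FakeTensors are pickled through a custom copyreg reducer: serialization stores only
# the TensorMeta (as JSON bytes) plus an is-parameter flag, and _reconstruct_fake_tensor
# rebuilds the FakeTensor via the currently-active GraphModuleDeserializer, which is
# tracked in the module-level _CURRENT_DESERIALIZER below.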
_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None


def _reduce_fake_tensor(fake_tensor: FakeTensor):
    is_parameter = isinstance(fake_tensor, torch.nn.Parameter)
    tensor_meta = serialize_tensor_meta(fake_tensor)
    tensor_meta_bytes = json.dumps(
        _dataclass_to_dict(tensor_meta), cls=EnumEncoder
    ).encode("utf-8")
    return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter)


def _reconstruct_fake_tensor(
    serialized_tensor_meta: bytes, is_parameter: bool
) -> FakeTensor:
    # Deserialize the bytes into a TensorMeta
    json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8"))
    tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta)
    # Find the current fake mode
    assert (
        _CURRENT_DESERIALIZER is not None
    ), "Need access to current deserializer state"
    fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
    if is_parameter:
        fake_tensor = torch.nn.Parameter(fake_tensor)  # type: ignore[assignment]
    return fake_tensor


def serialize_torch_artifact(artifact: Optional[Any]) -> bytes:
    if artifact is None:
        return b""

    assert (
        FakeTensor not in copyreg.dispatch_table
    ), "Refusing to stomp on existing FakeTensor reducer"
    try:
        copyreg.pickle(FakeTensor, _reduce_fake_tensor)
        buffer = io.BytesIO()

        # This is a workaround for the backend's tensor deserialization problem:
        # unpickleTensor() always creates a tensor on the device where it was originally saved.
        # This behavior is bad for multi-gpu training, as we wish to directly load the tensor
        # on the designated device.
        # For now, we simply move the tensor to cpu before saving.
        # TODO: this should be fixed by deserialization instead.
        torch.save(artifact, buffer)
        return buffer.getvalue()
    finally:
        del copyreg.dispatch_table[FakeTensor]


def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...], bytes]):
    if isinstance(serialized, (dict, tuple)):
        return serialized
    if len(serialized) == 0:
        return {}
    buffer = io.BytesIO(serialized)
    buffer.seek(0)
    artifact = torch.load(buffer)
    assert isinstance(artifact, (tuple, dict))
    return artifact


def _sympy_int_to_int(val: sympy.Expr, adjust: str):
    # Convert simple sympy Integers into concrete int
    if val == sympy.oo:
        return math.inf
    if val == -sympy.oo:
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)

    # TODO: Remove this adjustment when Ed gets rid of fractional ranges
    log.warning(
        "Export constraints cannot be non-integer expressions. Found "
        "type %s, and value %s. We will attempt to %s "
        "this value.", type(val), val, adjust
    )

    if adjust == "floor":
        return math.floor(val)
    elif adjust == "ceil":
        return math.ceil(val)
    else:
        raise RuntimeError(f"Got invalid adjustment {adjust}")


def _int_to_sympy_int(val) -> sympy.Expr:
    # Convert concrete int into simple sympy Integers
    if val == math.inf:
        return sympy.oo
    if val == -math.inf:
        return -sympy.oo
    return sympy.Integer(val)


def serialize_range_constraints(
    range_constraints: Dict[sympy.Symbol, ValueRanges]
) -> Dict[str, RangeConstraint]:
    return {
        str(k): RangeConstraint(
            _sympy_int_to_int(v.lower, "ceil"),  # type: ignore[arg-type]
            _sympy_int_to_int(v.upper, "floor"),  # type: ignore[arg-type]
        )
        for k, v in range_constraints.items()
    }
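

# Schema lookup: native ops carry their FunctionSchema on the OpOverload itself;
# other custom operator types are expected to provide one through the handler
# registered for them in _serialization_registry (defined elsewhere in this module).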
def _get_schema_from_target(target):
    if isinstance(target, torch._ops.OpOverload):
        return target._schema
    elif type(target) in _serialization_registry:
        return _serialization_registry[type(target)].op_schema(type(target))
    raise RuntimeError(f"Cannot find schema for {type(target)}")


def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
    schema = _get_schema_from_target(target)
    returns = schema.returns
    return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)


def _is_single_tensor_list_return(target: Any) -> bool:
    schema = _get_schema_from_target(target)
    returns = schema.returns

    if len(returns) != 1:
        return False
    return_type = returns[0].real_type
    return isinstance(return_type, torch.ListType) and isinstance(
        return_type.getElementType(), torch.TensorType
    )


@dataclass
class GraphState:
    inputs: List[Argument] = field(default_factory=list)
    outputs: List[Argument] = field(default_factory=list)
    nodes: List[Node] = field(default_factory=list)
    tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
    sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
    sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
    is_single_tensor_return: bool = False
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)


class Final(type):
    def __new__(metacls, name, bases, classdict):
        for b in bases:
            if isinstance(b, Final):
                raise TypeError(f"type '{b.__name__}' is not an acceptable base type")
        return type.__new__(metacls, name, bases, dict(classdict))
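

# GraphModuleSerializer walks an fx.Graph node by node: serialize_graph dispatches to
# handle_placeholder / handle_call_function / handle_output / handle_get_attr based on
# node.op, accumulating the serialized inputs, nodes, outputs, and value metadata in
# self.graph_state.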
@final
class GraphModuleSerializer(metaclass=Final):
    def __init__(
        self,
        graph_signature: ep.ExportGraphSignature,
        module_call_graph: List[ep.ModuleCallEntry],
    ):
        self.graph_state = GraphState()
        self.graph_signature = graph_signature
        self.module_call_graph = module_call_graph
        self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
        self.duplicate_getitem_nodes: Dict[str, str] = {}

    @contextmanager
    def save_graph_state(self):
        saved = self.graph_state
        self.graph_state = GraphState()
        try:
            yield
        finally:
            self.graph_state = saved

    def handle_placeholder(self, node: torch.fx.Node):
        assert node.op == "placeholder"
        if isinstance(node.meta["val"], torch.Tensor):
            graph_input = Argument.create(as_tensor=TensorArgument(name=node.name))
            self.graph_state.tensor_values[node.name] = serialize_tensor_meta(
                node.meta["val"]
            )
        elif isinstance(node.meta["val"], torch.SymInt):
            raise AssertionError("SymInt graph input is not implemented yet.")
        elif isinstance(node.meta["val"], (int, bool, str, float, type(None))):
            graph_input = self.serialize_input(node.meta["val"])
        elif isinstance(node.meta["val"], ep.CustomObjArgument):
            class_fqn = node.meta["val"].class_fqn
            graph_input = Argument.create(
                as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)
            )
            self.graph_state.custom_obj_values[node.name] = (
                self.serialize_script_obj_meta(node.meta["val"])
            )
        else:
            raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
        self.graph_state.inputs.append(graph_input)

    def handle_output(self, node: torch.fx.Node):
        assert node.op == "output"
        assert len(node.args) == 1, "FX.Node's args should have one arg"
        node_args = node.args[0]
        if isinstance(node_args, torch.fx.Node):
            # For singleton tensor returns
            self.graph_state.is_single_tensor_return = True
            self.graph_state.outputs = [self.serialize_input(node_args)]
        else:
            assert isinstance(node_args, (tuple, list))
            self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args]

    def serialize_operator(self, target) -> str:
        if isinstance(target, str):
            return target
        elif target.__module__.startswith("torch._ops"):
            # TODO(zhxchen17) Maybe provide a function name helper in FX.
            # From torch.fx.node._get_qualified_name
            module = target.__module__.replace("torch._ops", "torch.ops")
            return f"{module}.{target.__name__}"
        else:  # TODO(zhxchen17) Don't catch all here.
            return f"{target.__module__}.{target.__name__}"

    def handle_call_function(self, node: torch.fx.Node):
        assert node.op == "call_function"

        # getitem has been handled in the producer node, skip it here
        if node.target is operator.getitem:
            return

        if node.target in _SYM_INT_OPS:
            assert len(node.kwargs) == 0
            meta_val = node.meta["val"]
            ex_node = Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_sym_op_inputs(node.target, node.args),
                outputs=[
                    Argument.create(
                        as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
                    )
                ],
                metadata=self.serialize_metadata(node),
            )
        elif node.target in _SYM_BOOL_OPS:
            assert len(node.kwargs) == 0
            meta_val = node.meta["val"]
            ex_node = Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_sym_op_inputs(node.target, node.args),
                outputs=[
                    Argument.create(
                        as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
                    )
                ],
                metadata=self.serialize_metadata(node),
            )
        elif isinstance(node.target, torch._ops.OpOverload):
            ex_node = Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                outputs=self.serialize_outputs(node),
                # TODO: create a new tensor_values here, meta might have faketensor info
                metadata=self.serialize_metadata(node),
            )
        elif isinstance(node.target, torch._ops.HigherOrderOperator):
            ex_node = Node(
                target=self.serialize_operator(node.target),
                inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                outputs=self.serialize_hoo_outputs(node),
                metadata=self.serialize_metadata(node),
            )
        elif type(node.target) in _serialization_registry:
            custom_op_handler = node.target

            # Sanity check for unhandled serialization.
            assert (
                type(node.target) in _serialization_registry
            ), f"Missing CustomOpHandler for {type(node.target)}"

            handler = _serialization_registry[type(node.target)]
            ex_node = Node(
                target=f"${handler.namespace()}:{handler.op_name(node.target)}",
                inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                outputs=self.serialize_outputs(node),
                metadata=self.serialize_metadata(node),
            )
        else:
            raise SerializeError(f"Serializing {node.target} is not supported")

        self.graph_state.nodes.append(ex_node)

    def handle_get_attr(self, node):
        pass

    def _output_node_at_index(self, node, index):
        user_node = None
        for user in node.users:
            assert user.target is operator.getitem, f"{user} is not a getitem node"
            if index == user.args[1]:
                if user_node is None:
                    user_node = user
                else:
                    # We want to deduplicate getitem nodes that are trying to
                    # index to the same index
                    self.duplicate_getitem_nodes[user.name] = user_node.name
        return user_node

    def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
        ret = {}
        if stack_trace := node.meta.get("stack_trace"):
            ret["stack_trace"] = stack_trace

        if nn_module_stack := node.meta.get("nn_module_stack"):

            def export_nn_module_stack(val):
                assert isinstance(val, tuple) and len(val) == 2
                path, ty = val

                assert isinstance(path, str)
                assert isinstance(ty, str)

                return path + "," + ty

            # Serialize to "key,orig_path,type_str"
            nn_module_list = [
                f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items()
            ]
            ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list)

        if source_fn_st := node.meta.get("source_fn_stack"):
            source_fn_list = [
                f"{source_fn[0]},{self.serialize_operator(source_fn[1])}"
                for source_fn in source_fn_st
            ]
            ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list)

        if torch_fn := node.meta.get("torch_fn"):
            ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))

        return ret

    def serialize_script_obj_meta(
        self, script_obj_meta: ep.CustomObjArgument
    ) -> CustomObjArgument:
        return CustomObjArgument(
            name=script_obj_meta.name,
            class_fqn=script_obj_meta.class_fqn,
        )

    def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
        serialized_args = []
        args_names = inspect.signature(op).parameters.keys()
        for args_name, arg in zip(args_names, args):
            serialized_args.append(
                NamedArgument(name=args_name, arg=self.serialize_input(arg))
            )
        return serialized_args

    def serialize_inputs(
        self,
        target: Any,  # torch._ops.OpOverload and other custom operator types.
        args,
        kwargs=None,
    ) -> List[NamedArgument]:
        assert isinstance(target, (torch._ops.OpOverload, *allowed_registered_op_types()))
        kwargs = kwargs or {}
        serialized_args = []

        schema = _get_schema_from_target(target)

        for i, schema_arg in enumerate(schema.arguments):
            if schema_arg.name in kwargs:
                serialized_args.append(
                    NamedArgument(
                        name=schema_arg.name,
                        arg=self.serialize_input(kwargs[schema_arg.name], schema_arg.type),
                    )
                )
            elif not schema_arg.kwarg_only and i < len(args):
                serialized_args.append(
                    NamedArgument(
                        name=schema_arg.name,
                        arg=self.serialize_input(args[i], schema_arg.type),
                    )
                )
            else:
                # We intentionally don't serialize the missing arguments
                # with default values
                pass

        return serialized_args

    def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]:
        """
        For serializing HOO inputs since HOOs do not have a schema.
        """
        inputs = [
            NamedArgument(
                name="",
                arg=self.serialize_input(a),
            )
            for a in args
        ]
        inputs.extend(
            [
                NamedArgument(name=name, arg=self.serialize_input(a))
                for name, a in kwargs.items()
            ]
        )
        return inputs

    def is_sym_int_arg(self, arg) -> bool:
        return isinstance(arg, int) or (
            isinstance(arg, torch.fx.Node)
            and arg.name in self.graph_state.sym_int_values
        )

    def is_sym_bool_arg(self, arg) -> bool:
        return isinstance(arg, bool) or (
            isinstance(arg, torch.fx.Node)
            and arg.name in self.graph_state.sym_bool_values
        )
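
    # serialize_input maps a single argument to the schema's Argument union. The
    # branches below cover, in order: fx.Node references (nested graphs via get_attr,
    # symbolic ints/bools, custom objects, tensors), inductor tensor buffers, Python
    # scalars and None, (possibly empty) lists/tuples of the above, and torch dtypes,
    # devices, memory formats, layouts, ScriptObjects, and operators.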
    def serialize_input(
        self, arg, arg_type: Optional[torch._C.Argument] = None
    ) -> Argument:
        import torch._inductor.ir as inductor_ir

        inductor_tensor_buffers = (
            inductor_ir.Buffer,
            inductor_ir.ReinterpretView,
        )

        if isinstance(arg, torch.fx.Node):
            if arg.op == "get_attr":
                assert isinstance(arg.target, str)
                attr = getattr(arg.graph.owning_module, arg.target)

                if isinstance(attr, torch.Tensor):
                    raise SerializeError(
                        "getattr nodes containing tensors should not appear in the graph"
                    )
                elif isinstance(attr, torch.fx.GraphModule):
                    with self.save_graph_state():
                        graph = self.serialize_graph(attr)
                    return Argument.create(
                        as_graph=GraphArgument(name=arg.target, graph=graph)
                    )
                else:
                    raise SerializeError(
                        f"Unsupported getattr attribute {arg.target} with type: {type(attr)}"
                    )
            elif self.is_sym_int_arg(arg):
                return Argument.create(
                    as_sym_int=SymIntArgument.create(as_name=arg.name)
                )
            elif self.is_sym_bool_arg(arg):
                return Argument.create(
                    as_sym_bool=SymBoolArgument.create(as_name=arg.name)
                )
            elif isinstance(arg.meta["val"], ep.CustomObjArgument):
                return Argument.create(
                    as_custom_obj=CustomObjArgument(
                        name=arg.name, class_fqn=arg.meta["val"].class_fqn
                    )
                )
            elif arg.name in self.duplicate_getitem_nodes:
                dedup_name = self.duplicate_getitem_nodes[arg.name]
                return Argument.create(as_tensor=TensorArgument(name=dedup_name))
            else:
                return Argument.create(as_tensor=TensorArgument(name=arg.name))
        elif isinstance(arg, inductor_tensor_buffers):
            # Other branches are for arguments that appear as fx nodes.
            # This is a special branch for handling buffers (representing tensor arguments)
            # for inductor's ExternalFallbackNode;
            # export_extern_kernel_node() uses this function to serialize arguments.
            arg_name = arg.get_name()
            assert arg_name is not None, "Buffer must have a valid name"
            return Argument.create(as_tensor=TensorArgument(name=arg_name))
        elif isinstance(arg, torch.SymInt):
            # This is a special branch for handling SymInt args in inductor's
            # ExternalFallbackNode.
            # For a regular FX graph, a SymInt arg should be an fx.Node with
            # self.is_sym_int_arg(arg) being true
            return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
        elif isinstance(arg, bool):
            return Argument.create(as_bool=arg)
        elif isinstance(arg, str):
            return Argument.create(as_string=arg)
        elif isinstance(arg, int):
            return Argument.create(as_int=arg)
        elif isinstance(arg, float):
            return Argument.create(as_float=arg)
        elif arg is None:
            return Argument.create(as_none=())
        elif isinstance(arg, (list, tuple)):
            if len(arg) == 0:
                if arg_type is not None:
                    if isinstance(arg_type, torch.OptionalType):
                        arg_type = arg_type.getElementType()  # type: ignore[assignment]
                    assert isinstance(arg_type, torch.ListType)
                    elem_type = arg_type.getElementType()
                    if isinstance(elem_type, torch.OptionalType):
                        elem_type = elem_type.getElementType()

                    if isinstance(elem_type, torch.BoolType):
                        return Argument.create(as_bools=[])
                    elif isinstance(elem_type, torch.IntType):
                        return Argument.create(as_ints=[])
                    elif isinstance(elem_type, torch.FloatType):
                        return Argument.create(as_floats=[])
                    elif isinstance(elem_type, torch.StringType):
                        return Argument.create(as_strings=[])
                    elif isinstance(elem_type, torch.TensorType):
                        return Argument.create(as_tensors=[])
                    else:
                        # I believe empty symint lists default to ints, but
                        # please file an issue if this is not the case
                        raise SerializeError(f"Empty list with type {elem_type} nyi.")
                else:
                    # We could serialize this by default to a tensor list. This
                    # is needed in the HOO case
                    log.warning(
                        "Unsure how to serialize the given empty list, "
                        "as we don't know what is the type of this argument. "
                        "Serializing it as a tensor list by default."
                    )
                    return Argument.create(as_tensors=[])

            # Must check bool first, as bool is also treated as int
            if all(isinstance(a, bool) for a in arg):
                return Argument.create(as_bools=list(arg))
            elif all(isinstance(a, int) for a in arg):
                return Argument.create(as_ints=list(arg))
            elif all(isinstance(a, float) for a in arg):
                return Argument.create(as_floats=list(arg))
            elif all(isinstance(a, str) for a in arg):
                return Argument.create(as_strings=list(arg))
            elif all(isinstance(a, torch.SymInt) for a in arg):
                # This is a special branch for handling SymInt args in inductor's
                # ExternalFallbackNode.
                # For a regular FX graph, a SymInt arg should be an fx.Node with
                # self.is_sym_int_arg(arg) being true
                return Argument.create(
                    as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg]
                )
            elif all(self.is_sym_int_arg(a) for a in arg):
                # list of sym_ints
                values = []
                for a in arg:
                    if isinstance(a, torch.fx.Node):
                        values.append(SymIntArgument.create(as_name=a.name))
                    elif isinstance(a, int):
                        values.append(SymIntArgument.create(as_int=a))
                return Argument.create(as_sym_ints=values)
            elif all(self.is_sym_bool_arg(a) for a in arg):
                # list of sym_bools
                values = []
                for a in arg:
                    if isinstance(a, torch.fx.Node):
                        values.append(SymBoolArgument.create(as_name=a.name))
                    elif isinstance(a, bool):
                        values.append(SymBoolArgument.create(as_bool=a))
                return Argument.create(as_sym_bools=values)
            elif all(isinstance(a, torch.fx.Node) for a in arg):
                # list of tensors
                arguments = []
                for a in arg:
                    if a.op == "get_attr":
                        raise SerializeError(
                            "getattr nodes containing tensors should not appear in the graph"
                        )
                    arguments.append(TensorArgument(name=a.name))
                return Argument.create(as_tensors=arguments)
            elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
                # list of optional tensors
                def serialize_optional_tensor_args(a):
                    if a is None:
                        return OptionalTensorArgument.create(as_none=())
                    elif isinstance(a, torch.fx.Node):
                        return OptionalTensorArgument.create(
                            as_tensor=TensorArgument(name=a.name)
                        )
                    else:
                        raise SerializeError(f"Unsupported list/tuple argument: {a}")

                return Argument.create(
                    as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
                )
            elif all(isinstance(a, inductor_tensor_buffers) for a in arg):
                # list of inductor buffers
                return Argument.create(
                    as_tensors=[TensorArgument(name=a.get_name()) for a in arg],
                )
            elif all(
                isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg
            ):
                # list of inductor buffers as optional tensors
                def serialize_optional_tensor_args(a):
                    if a is None:
                        return OptionalTensorArgument.create(as_none=())
                    elif isinstance(a, inductor_tensor_buffers):
                        return OptionalTensorArgument.create(
                            as_tensor=TensorArgument(name=a.get_name())
                        )
                    else:
                        raise SerializeError(f"Unsupported list/tuple argument: {a}")

                return Argument.create(
                    as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
                )
            else:
                raise SerializeError(
                    f"Unsupported list/tuple argument type: {[type(a) for a in arg]}"
                )
        elif isinstance(arg, torch.dtype):
            return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg])
        elif isinstance(arg, torch.device):
            return Argument.create(as_device=Device(type=arg.type, index=arg.index))
        elif isinstance(arg, torch.memory_format):
            return Argument.create(
                as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]
            )
        elif isinstance(arg, torch.layout):
            return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg])
        elif isinstance(arg, torch._C.ScriptObject):
            if not (
                arg._has_method("__getstate__")  # type: ignore[attr-defined]
                and arg._has_method("__setstate__")  # type: ignore[attr-defined]
            ):
                raise SerializeError(
                    f"Unable to serialize custom class {arg}. Please define "
                    "serialization methods via def_pickle()."
                )
            # Custom objects through torchbind are serializable with pickle,
            # through implementing the .def_pickle function. This should result
            # in the object containing a __getstate__ and __setstate__
            # serialize/deserialize function.
            custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
            self.custom_objs[custom_obj_name] = arg
            class_fqn = arg._type().qualified_name()  # type: ignore[attr-defined]
            return Argument.create(
                as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)
            )
        elif isinstance(arg, torch._ops.OpOverload):
            return Argument.create(as_operator=self.serialize_operator(arg))
        else:
            raise SerializeError(f"Unsupported argument type: {type(arg)}")

    def serialize_tensor_output(self, name, meta_val) -> TensorArgument:
        assert name not in self.graph_state.tensor_values
        self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val)
        return TensorArgument(name=name)

    def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument:
        assert name not in self.graph_state.sym_int_values
        self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val)
        return SymIntArgument.create(as_name=name)

    def serialize_sym_bool_output(self, name, meta_val) -> SymBoolArgument:
        assert name not in self.graph_state.sym_bool_values
        self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
        return SymBoolArgument.create(as_name=name)
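
    # serialize_input_spec / serialize_output_spec translate the export graph
    # signature's ep.InputSpec / ep.OutputSpec entries (keyed by InputKind / OutputKind)
    # into the corresponding tagged-union variants of the serialized schema.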
    def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
        if spec.kind == ep.InputKind.USER_INPUT:
            if isinstance(spec.arg, ep.ConstantArgument):
                if isinstance(spec.arg.value, int):
                    constant_spec = ConstantValue.create(as_int=spec.arg.value)
                elif isinstance(spec.arg.value, bool):
                    constant_spec = ConstantValue.create(as_bool=spec.arg.value)
                elif isinstance(spec.arg.value, str):
                    constant_spec = ConstantValue.create(as_string=spec.arg.value)
                elif isinstance(spec.arg.value, float):
                    constant_spec = ConstantValue.create(as_float=spec.arg.value)
                elif spec.arg.value is None:
                    constant_spec = ConstantValue.create(as_none=())
                else:
                    raise SerializeError(f"Unhandled constant input {spec.arg.value} to serialize")
                return InputSpec.create(
                    constant_input=ConstantInputSpec(
                        name=spec.arg.name, value=constant_spec
                    )
                )
            else:
                return InputSpec.create(
                    user_input=UserInputSpec(
                        arg=self.serialize_argument_spec(spec.arg)
                    )
                )
        elif spec.kind == ep.InputKind.PARAMETER:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return InputSpec.create(
                parameter=InputToParameterSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    parameter_name=spec.target,
                )
            )
        elif spec.kind == ep.InputKind.BUFFER:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            assert spec.persistent is not None
            return InputSpec.create(
                buffer=InputToBufferSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    buffer_name=spec.target,
                    persistent=spec.persistent,
                )
            )
        elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return InputSpec.create(
                tensor_constant=InputToTensorConstantSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    tensor_constant_name=spec.target,
                )
            )
        elif spec.kind == ep.InputKind.CUSTOM_OBJ:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.CustomObjArgument)
            return InputSpec.create(
                custom_obj=InputToCustomObjSpec(
                    arg=CustomObjArgument(
                        name=spec.arg.name, class_fqn=spec.arg.class_fqn
                    ),
                    custom_obj_name=spec.target,
                )
            )
        elif spec.kind == ep.InputKind.TOKEN:
            assert isinstance(spec.arg, ep.TokenArgument)
            return InputSpec.create(
                token=InputTokenSpec(
                    arg=TokenArgument(name=spec.arg.name),
                )
            )
        else:
            raise AssertionError(f"Unknown argument kind: {spec}")

    def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
        if spec.kind == ep.OutputKind.USER_OUTPUT:
            return OutputSpec.create(
                user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
            )
        elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
            assert isinstance(spec.arg, ep.TensorArgument)
            return OutputSpec.create(
                loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name))
            )
        elif spec.kind == ep.OutputKind.BUFFER_MUTATION:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return OutputSpec.create(
                buffer_mutation=BufferMutationSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    buffer_name=spec.target,
                )
            )
        elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return OutputSpec.create(
                gradient_to_parameter=GradientToParameterSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    parameter_name=spec.target,
                )
            )
        elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return OutputSpec.create(
                gradient_to_user_input=GradientToUserInputSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    user_input_name=spec.target,
                )
            )
        elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION:
            assert spec.target is not None
            assert isinstance(spec.arg, ep.TensorArgument)
            return OutputSpec.create(
                user_input_mutation=UserInputMutationSpec(
                    arg=TensorArgument(name=spec.arg.name),
                    user_input_name=spec.target,
                )
            )
        elif spec.kind == ep.OutputKind.TOKEN:
            assert isinstance(spec.arg, ep.TokenArgument)
            return OutputSpec.create(
                token=OutputTokenSpec(
                    arg=TokenArgument(name=spec.arg.name),
                )
            )
        else:
            raise AssertionError(f"Unknown argument kind: {spec}")

    def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature:
        return GraphSignature(
            input_specs=[self.serialize_input_spec(s) for s in sig.input_specs],
            output_specs=[self.serialize_output_spec(s) for s in sig.output_specs],
        )

    def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument:
        if isinstance(x, ep.TensorArgument):
            return Argument.create(as_tensor=TensorArgument(name=x.name))
        elif isinstance(x, ep.SymIntArgument):
            return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name))
        elif isinstance(x, ep.ConstantArgument):
            return self.serialize_input(x.value)
        elif isinstance(x, ep.CustomObjArgument):
            return Argument.create(
                as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)
            )
        else:
            raise AssertionError("TODO")

    def serialize_module_call_signature(
        self, module_call_signature: ep.ModuleCallSignature
    ) -> ModuleCallSignature:
        return ModuleCallSignature(
            inputs=[
                self.serialize_argument_spec(x) for x in module_call_signature.inputs
            ],
            outputs=[
                self.serialize_argument_spec(x) for x in module_call_signature.outputs
            ],
            in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION),
            out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION),
        )

    def serialize_module_call_graph(
        self, module_call_graph: List[ep.ModuleCallEntry]
    ) -> List[ModuleCallEntry]:
        return [
            ModuleCallEntry(
                fqn=entry.fqn,
                signature=(
                    self.serialize_module_call_signature(entry.signature)
                    if entry.signature
                    else None
                ),
            )
            for entry in module_call_graph
        ]

    def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
        """For a given node, return the dataclass representing its output values.

        [NOTE: Multiple outputs] We handle aggregates differently than FX. For
        FX, it looks like:

            x = call_function("multiple_return", ...)
            element0 = call_function(getitem, x, 0)
            foo = call_function("use_output", element0)

        We do not want the intermediate `getitem` call, so our serialized form looks like:

            element0, element1, element2 = call_function("multiple_return", ...)
            foo = call_function("use_output", element0)

        We want names to be consistent across these two schemes, so that we can
        mostly reuse the names coming from FX. This function computes a mapping from
        the FX representation to our representation, preserving the names.
        """
        assert node.op == "call_function" and isinstance(
            node.target, (torch._ops.OpOverload, *allowed_registered_op_types())
        )

        schema = _get_schema_from_target(node.target)
        returns = schema.returns

        if len(returns) == 0:
            return []

        meta_val = node.meta["val"]

        # Check single value return
        if _is_single_tensor_list_return(node.target):
            # e.g "-> Tensor[]"
            tensor_args = []
            for idx, meta in enumerate(meta_val):
                user_node = self._output_node_at_index(node, idx)
                name = (
                    user_node.name
                    if user_node is not None
                    else f"{node.name}_unused_{idx}"
                )
                tensor_args.append(self.serialize_tensor_output(name, meta))
            return [Argument.create(as_tensors=tensor_args)]
        elif len(returns) == 1:
            return [self.serialize_output(node.name, meta_val)]

        # There are two possibilities at this point:
        # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
        # - This operator returns a tuple mixing Tensors and Tensor lists, e.g. "-> (Tensor, Tensor[])"
        #
        # Either way, start by gathering a list of TensorArguments with the correct names.
        # For consistent naming with FX, consult the downstream `getitem` node and
        # make sure our outputs have the same name.
        output_arguments = []
        for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
            if meta is None:
                assert isinstance(
                    return_schema.real_type, (torch.OptionalType, torch.TensorType)
                )
                # When the return type is annotated as Tensor type, the op can also return an
                # undefined Tensor which will be implicitly converted to None in Python.
                output_arguments.append(Argument.create(as_none=()))
            elif isinstance(meta, FakeTensor):
                assert isinstance(
                    return_schema.real_type, (torch.OptionalType, torch.TensorType)
                )
                user_node = self._output_node_at_index(node, idx)
                name = (
                    user_node.name
                    if user_node is not None
                    else f"{node.name}_unused_{idx}"
                )
                output_arguments.append(self.serialize_output(name, meta))
            elif isinstance(meta, list):
                # for List[Tensor] return type
                assert isinstance(
                    return_schema.real_type, torch.ListType
                ) and isinstance(
                    return_schema.real_type.getElementType(), torch.TensorType
                )
                user_node = self._output_node_at_index(node, idx)
                assert user_node is not None

                args = []
                for i, m in enumerate(meta):
                    if m is None:
                        continue
                    sub_user_node = self._output_node_at_index(user_node, i)
                    assert sub_user_node is not None, f"No user found at index {i}"

                    args.append(self.serialize_tensor_output(sub_user_node.name, m))
                output_arguments.append(Argument.create(as_tensors=args))
            elif isinstance(meta, (int, torch.SymInt)):
                user_node = self._output_node_at_index(node, idx)
                name = (
                    user_node.name
                    if user_node is not None
                    else f"{node.name}_unused_{idx}"
                )
                output_arguments.append(self.serialize_output(name, meta))
            else:
                raise ValueError(
                    f"Unhandled output type {type(meta)} from node {node.format_node()}"
                )

        return output_arguments

    def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:
        """
        For serializing HOO outputs since HOOs do not have a schema.
        """
        meta_val = node.meta["val"]

        if isinstance(meta_val, tuple):
            # Note: Since we don't have a schema, we just serialize all tuple
            # outputs to be a list of values. Even if the output is supposed to
            # be a tensor list (Tensor[]), we will serialize it to be a list of
            # tensors (Tensor, Tensor, Tensor). An exception is that if there's
            # a singleton tensor, we will serialize this to be a singleton
            # tensor list so that the deserializer knows to insert getitem nodes.

            if len(meta_val) == 1:
                assert isinstance(meta_val[0], torch.Tensor)
                user_node = self._output_node_at_index(node, 0)
                name = (
                    user_node.name
                    if user_node is not None
                    else f"{node.name}_unused_0"
                )
                return [
                    Argument.create(
                        as_tensors=[self.serialize_tensor_output(name, meta_val[0])]
                    )
                ]

            outputs = []
            for i, element_meta_val in enumerate(meta_val):
                user_node = self._output_node_at_index(node, i)
                if isinstance(element_meta_val, list):
                    # e.g "-> Tensor[]"
                    assert user_node is not None

                    tensors = []
                    for j, m in enumerate(element_meta_val):
                        if not isinstance(m, torch.Tensor):
                            raise SerializeError(
                                f"Serialize list output with type {type(m)} nyi"
                            )

                        sub_user_node = self._output_node_at_index(user_node, j)
                        name = (
                            sub_user_node.name
                            if sub_user_node is not None
                            else f"{user_node.name}_unused_{j}"
                        )
                        tensors.append(self.serialize_tensor_output(name, m))
                    outputs.append(Argument.create(as_tensors=tensors))
                else:
                    name = (
                        user_node.name
                        if user_node is not None
                        else f"{node.name}_unused_{i}"
                    )
                    outputs.append(self.serialize_output(name, element_meta_val))

            return outputs
        else:
            return [self.serialize_output(node.name, meta_val)]
  1118. def serialize_output(self, name: str, meta_val: Any) -> Argument:
  1119. # Check single value return
  1120. if meta_val is None:
  1121. return Argument.create(as_none=())
  1122. if isinstance(meta_val, torch.Tensor):
  1123. # e.g "-> Tensor"
  1124. return Argument.create(
  1125. as_tensor=self.serialize_tensor_output(name, meta_val)
  1126. )
  1127. elif isinstance(meta_val, (int, torch.SymInt)):
  1128. # e.g "-> SymInt"
  1129. return Argument.create(
  1130. as_sym_int=self.serialize_sym_int_output(name, meta_val)
  1131. )
  1132. elif isinstance(meta_val, torch.SymBool):
  1133. # e.g "-> SymBool"
  1134. return Argument.create(
  1135. as_sym_bool=self.serialize_sym_bool_output(name, meta_val)
  1136. )
  1137. # list outputs should've been handled earlier
  1138. raise SerializeError(f"Unable to serialize output {meta_val}")
  1139. def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
  1140. meta_val = node.meta["val"]
  1141. idx_to_name = {}
  1142. for user in node.users:
  1143. assert (
  1144. user.target is operator.getitem
  1145. ), f"User node {user} of {node} is incorrect"
  1146. idx_to_name[user.args[1]] = user.name
  1147. for idx, _ in enumerate(meta_val):
  1148. # FX does not emit a getitem node for any outputs that are unused.
  1149. # However, we need a name for them so that the number of outputs will
  1150. # correctly match the schema. Just assign a dummy name.
  1151. if idx not in idx_to_name:
  1152. idx_to_name[idx] = f"{node.name}_unused_{idx}"
  1153. arg_list = []
  1154. for i, element_meta_val in enumerate(meta_val):
  1155. arg_list.append(
  1156. self.serialize_tensor_output(idx_to_name[i], element_meta_val)
  1157. )
  1158. return arg_list
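# Illustrative note: for a node with three tensor outputs where only outputs
# 0 and 2 have getitem users, idx_to_name above ends up roughly as
# {0: <user name>, 2: <user name>, 1: "<node>_unused_1"}, so the serialized
# output count still matches the operator schema even for unused outputs.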
  1159. def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
  1160. assert isinstance(graph_module, torch.fx.GraphModule)
  1161. for node in graph_module.graph.nodes:
  1162. try:
  1163. getattr(self, f"handle_{node.op}")(node)
  1164. except Exception as e:
  1165. raise SerializeError(
  1166. f"Failed serializing node {node} in graph: {node.format_node()}"
  1167. ) from e
  1168. return Graph(
  1169. inputs=self.graph_state.inputs,
  1170. nodes=self.graph_state.nodes,
  1171. tensor_values=self.graph_state.tensor_values,
  1172. sym_int_values=self.graph_state.sym_int_values,
  1173. sym_bool_values=self.graph_state.sym_bool_values,
  1174. custom_obj_values=self.graph_state.custom_obj_values,
  1175. outputs=self.graph_state.outputs,
  1176. is_single_tensor_return=self.graph_state.is_single_tensor_return,
  1177. )
  1178. def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
  1179. graph = self.serialize_graph(graph_module)
  1180. return GraphModule(
  1181. graph=graph,
  1182. signature=self.serialize_signature(self.graph_signature),
  1183. module_call_graph=self.serialize_module_call_graph(self.module_call_graph),
  1184. )
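# A minimal usage sketch (illustrative only; this serializer is normally
# driven through ExportedProgramSerializer below):
#
#   gm_serializer = GraphModuleSerializer(
#       exported_program.graph_signature, exported_program.module_call_graph
#   )
#   serialized_gm = gm_serializer.serialize(exported_program.graph_module)  # -> GraphModule dataclass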
  1185. @final
  1186. class ExportedProgramSerializer(metaclass=Final):
  1187. def __init__(self, opset_version: Optional[Dict[str, int]] = None):
  1188. self.opset_version: Dict[str, int] = {}
  1189. if opset_version:
  1190. self.opset_version.update(opset_version)
  1191. if "aten" not in self.opset_version:
  1192. self.opset_version["aten"] = torch._C._get_max_operator_version()
  1193. def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
  1194. """
  1195. Args:
  1196. exported_program: Exported Program to serialize
  1197. """
  1198. exported_program._validate()
  1199. gm_serializer = GraphModuleSerializer(
  1200. exported_program.graph_signature, exported_program.module_call_graph
  1201. )
  1202. serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
  1203. serialized_range_constraints = serialize_range_constraints(
  1204. exported_program.range_constraints
  1205. )
  1206. # TODO: Directly serialize exported_program.constants once
  1207. # CustomClassHolders get stored in the ExportedProgram rather than in
  1208. # the graph
  1209. constants = {}
  1210. for n, c in gm_serializer.custom_objs.items():
  1211. constants[n] = c
  1212. for n, t in exported_program.constants.items():
  1213. assert n not in constants
  1214. constants[n] = t
  1215. serialized_ep = ExportedProgram(
  1216. graph_module=serialized_graph_module,
  1217. opset_version=self.opset_version,
  1218. range_constraints=serialized_range_constraints,
  1219. schema_version=SchemaVersion(
  1220. major=SCHEMA_VERSION[0],
  1221. minor=SCHEMA_VERSION[1],
  1222. ),
  1223. dialect=exported_program.dialect
  1224. )
  1225. # Test canonical form is well defined.
  1226. canonicalize(serialized_ep)
  1227. return _SerializedProgram(
  1228. serialized_ep,
  1229. serialize_torch_artifact(exported_program.state_dict),
  1230. serialize_torch_artifact(constants),
  1231. serialize_torch_artifact(exported_program.example_inputs),
  1232. )
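# A minimal usage sketch, assuming `exported_program` is an ep.ExportedProgram
# (illustrative only):
#
#   serialized = ExportedProgramSerializer().serialize(exported_program)
#   # `serialized` is a _SerializedProgram bundling the ExportedProgram
#   # dataclass with the pickled state_dict, constants, and example_inputs.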
  1233. @final
  1234. class GraphModuleDeserializer(metaclass=Final):
  1235. @dataclasses.dataclass
  1236. class Result:
  1237. graph_module: torch.fx.GraphModule
  1238. signature: ep.ExportGraphSignature
  1239. module_call_graph: List[ep.ModuleCallEntry]
  1240. names_to_symbols: Dict[str, sympy.Symbol]
  1241. state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
  1242. constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]]
  1243. example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]
  1244. def __init__(self):
  1245. self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
  1246. self.serialized_name_to_meta: Dict[str, MetaType] = {}
  1247. self.graph = torch.fx.Graph()
  1248. self.module = torch.nn.Module()
  1249. @contextmanager
  1250. def save_graph_module(self) -> Iterator[None]:
  1251. saved = (
  1252. self.graph,
  1253. self.module,
  1254. self.serialized_name_to_node,
  1255. self.serialized_name_to_meta,
  1256. )
  1257. self.graph = torch.fx.Graph()
  1258. self.module = torch.nn.Module()
  1259. self.serialized_name_to_node = {}
  1260. self.serialized_name_to_meta = {}
  1261. try:
  1262. yield
  1263. finally:
  1264. (
  1265. self.graph,
  1266. self.module,
  1267. self.serialized_name_to_node,
  1268. self.serialized_name_to_meta,
  1269. ) = saved
  1270. def deserialize_operator(self, serialized_target: str):
  1271. if serialized_target.startswith(
  1272. "_operator"
  1273. ): # TODO(zhxchen17) Follow up on this.
  1274. module = operator
  1275. serialized_target_names = serialized_target.split(".")[1:]
  1276. elif serialized_target.startswith("torch"):
  1277. module = torch # type: ignore[misc]
  1278. serialized_target_names = serialized_target.split(".")[1:]
  1279. else: # TODO(zhxchen17) Don't catch all here.
  1280. return serialized_target
  1281. target = module
  1282. for name in serialized_target_names:
  1283. if not hasattr(target, name):
  1284. return serialized_target
  1285. else:
  1286. target = getattr(target, name)
  1287. return target
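# Illustrative example of the resolution above: a serialized target such as
# "torch.ops.aten.add.Tensor" starts with "torch", so we walk getattr chains
# ("ops", "aten", "add", "Tensor") to recover the OpOverload, while
# "_operator.getitem" resolves against the `operator` module; anything else
# is returned unchanged as a string.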
  1288. def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
  1289. val = s.value
  1290. if s.type == "as_expr":
  1291. if val.hint is None:
  1292. hint = None
  1293. else:
  1294. assert val.hint.type == "as_int"
  1295. hint = val.hint.value
  1296. if val.expr_str in self.symbol_name_to_symbol:
  1297. sym = self.symbol_name_to_symbol[val.expr_str]
  1298. else:
  1299. sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
  1300. # NOTE(avik): Assumptions on symbols are not explicitly serialized.
  1301. # This seems dangerous: it might cause unknown differences in shape env behavior
  1302. # on deserialization? Probably deserves a follow-up.
  1303. # Here we force symbols corresponding to SymInts to be at least integers.
  1304. # Otherwise some expressions that the shape env would otherwise evaluate to False,
  1305. # e.g., 2*s = 9, can have rational solutions, e.g., 9/2.
  1306. # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024)
  1307. sym = sym.subs(
  1308. {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}
  1309. )
1310. # We need to check whether the symbol has already been allocated;
1311. # checking self.symbol_name_to_symbol alone is not enough because the
1312. # integer-ification of symbols can induce simplification;
1313. # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral
  1314. if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val:
  1315. self.symbol_name_to_symbol[val.expr_str] = sym
  1316. if hint is not None:
  1317. self.shape_env.add_var_to_val(sym, hint)
  1318. if vr := self.symbol_name_to_range.get(val.expr_str):
  1319. self.shape_env.constrain_symbol_range(
  1320. sym,
  1321. compiler_min=vr.lower, # type: ignore[arg-type]
  1322. compiler_max=vr.upper, # type: ignore[arg-type]
  1323. )
  1324. else:
  1325. # Placeholders, in particular, can have shapes as symbolic expressions.
  1326. # We need to populate the shape env with the range constraints of their
  1327. # free symbols, otherwise evaluating such expressions will error.
  1328. self.symbol_name_to_symbol[val.expr_str] = sym
  1329. free_symbols = sym.free_symbols
  1330. for s in free_symbols:
  1331. if s.name not in self.symbol_name_to_symbol:
  1332. self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment]
  1333. if vr := self.symbol_name_to_range.get(s.name):
  1334. self.shape_env.constrain_symbol_range(
  1335. s,
  1336. compiler_min=vr.lower, # type: ignore[arg-type]
  1337. compiler_max=vr.upper, # type: ignore[arg-type]
  1338. )
  1339. return self.shape_env.create_symintnode(sym, hint=hint)
  1340. elif s.type == "as_int":
  1341. assert isinstance(val, int)
  1342. return val
  1343. else:
  1344. raise SerializeError(
  1345. f"SymInt has invalid field type {s.type} with value {s.value}"
  1346. )
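# Illustrative note: an "as_int" SymInt round-trips to a plain Python int,
# while an "as_expr" SymInt such as "s0 + 1" (with an optional integer hint)
# is re-sympified, its free symbols are registered with the shape env along
# with any known range constraints, and a SymInt node is created from the
# resulting expression.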
  1347. def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
  1348. val = s.value
  1349. if s.type == "as_expr":
  1350. expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
  1351. return self.shape_env.create_symboolnode(expr)
  1352. elif s.type == "as_bool":
  1353. assert isinstance(val, bool)
  1354. return val
  1355. else:
  1356. raise SerializeError(
  1357. f"SymBool has invalid field type {s.type} with value {s.value}"
  1358. )
  1359. def deserialize_tensor_meta(
  1360. self,
  1361. tensor_meta: TensorMeta,
  1362. ) -> FakeTensor:
  1363. with self.fake_tensor_mode:
  1364. return cast(
  1365. FakeTensor,
  1366. torch.empty_strided(
  1367. tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc]
  1368. tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc]
  1369. device=deserialize_device(tensor_meta.device),
  1370. dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype],
  1371. ),
  1372. )
  1373. def deserialize_script_obj_meta(
  1374. self, script_obj_meta: CustomObjArgument
  1375. ) -> ep.CustomObjArgument:
  1376. return ep.CustomObjArgument(
  1377. name=script_obj_meta.name,
  1378. class_fqn=script_obj_meta.class_fqn,
  1379. )
  1380. def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
  1381. if output.type == "as_tensor":
  1382. return self.serialized_name_to_node[output.as_tensor.name]
  1383. elif output.type == "as_sym_int":
  1384. return self.serialized_name_to_node[output.as_sym_int.as_name]
  1385. elif output.type == "as_sym_bool":
  1386. return self.serialized_name_to_node[output.as_sym_bool.as_name]
  1387. elif output.type == "as_int":
  1388. return output.as_int
  1389. elif output.type == "as_none":
  1390. return None
  1391. else:
  1392. raise SerializeError(f"Unable to deserialize output node {output}")
  1393. def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
  1394. # Handle the tensor metas.
  1395. for name, tensor_value in serialized_graph.tensor_values.items():
  1396. meta_val = self.deserialize_tensor_meta(tensor_value)
  1397. self.serialized_name_to_meta[name] = meta_val
  1398. for name, sym_int_value in serialized_graph.sym_int_values.items():
  1399. self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)
  1400. for name, sym_bool_value in serialized_graph.sym_bool_values.items():
  1401. self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
  1402. sym_bool_value
  1403. )
  1404. for name, script_obj_meta in serialized_graph.custom_obj_values.items():
  1405. self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(
  1406. script_obj_meta
  1407. )
  1408. # Inputs: convert to placeholder nodes in FX.
  1409. for i, input_ in enumerate(serialized_graph.inputs):
  1410. if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"):
  1411. node_name = input_.value.name
  1412. placeholder_node = self.graph.placeholder(node_name)
1413. # FX might reject this name as illegal (e.g. some nn.Modules use "input" as a forward() argument),
1414. # so we overwrite it here.
  1415. placeholder_node.name = node_name
  1416. self.sync_fx_node(node_name, placeholder_node)
  1417. elif input_.type in (
  1418. "as_int",
  1419. "as_float",
  1420. "as_bool",
  1421. "as_none",
  1422. "as_string",
  1423. ):
  1424. node_name = self.signature.input_specs[i].arg.name
  1425. placeholder_node = self.graph.placeholder(node_name)
  1426. placeholder_node.meta["val"] = self.deserialize_input(input_)
  1427. else:
  1428. raise SerializeError(f"Invalid input type {input_}")
  1429. # Nodes: convert to call_function nodes.
  1430. for serialized_node in serialized_graph.nodes:
  1431. try:
  1432. target = self.deserialize_operator(serialized_node.target)
  1433. self.deserialize_node(serialized_node, target)
  1434. except Exception as e:
  1435. raise SerializeError(
  1436. f"Failed deserializing node {serialized_node}"
  1437. ) from e
  1438. # Outputs: convert to a single `output` node.
  1439. outputs = []
  1440. for output in serialized_graph.outputs:
  1441. outputs.append(self.deserialize_graph_output(output))
  1442. if serialized_graph.is_single_tensor_return:
  1443. assert len(outputs) == 1
  1444. outputs = outputs[0] # type: ignore[assignment]
  1445. else:
  1446. outputs = tuple(outputs) # type: ignore[assignment]
  1447. output_node = self.graph.output(outputs)
  1448. if serialized_graph.is_single_tensor_return:
  1449. output_node.meta["val"] = output_node.args[0].meta["val"]
  1450. else:
  1451. output_node.meta["val"] = tuple(
  1452. arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
  1453. for arg in output_node.args[0]
  1454. )
  1455. return self.graph
  1456. def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
  1457. if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
  1458. name = serialized_node.outputs[0].value.as_name
  1459. args = self.deserialize_sym_op_inputs(serialized_node.inputs)
  1460. fx_node = self.graph.create_node("call_function", target, args, {}, name)
  1461. self.deserialize_sym_op_outputs(serialized_node, fx_node)
  1462. elif isinstance(target, torch._ops.HigherOrderOperator):
  1463. args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs)
1464. # If the HOP returns a single tensor, name the
1465. # newly-created node after it. This ensures that these tensor values
1466. # have names consistent with the serialized graph.
1467. #
1468. # HOPs don't have schemas yet, so just check the output length and the as_tensor attribute.
  1469. name = (
  1470. serialized_node.outputs[0].as_tensor.name
  1471. if len(serialized_node.outputs) == 1
  1472. and hasattr(serialized_node.outputs[0], "as_tensor")
  1473. else None
  1474. )
  1475. fx_node = self.graph.create_node(
  1476. "call_function", target, args, kwargs, name
  1477. )
  1478. self.deserialize_outputs(serialized_node, fx_node)
  1479. fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
  1480. elif isinstance(target, torch._ops.OpOverload):
1481. # For convenience: if this node returns a single tensor, name the
1482. # newly-created node after it. This ensures that these tensor values
1483. # have names consistent with the serialized graph.
  1484. name = (
  1485. serialized_node.outputs[0].as_tensor.name
  1486. if _is_single_tensor_return(target)
  1487. else None # FX will generate a name for us.
  1488. )
  1489. args, kwargs = self.deserialize_inputs(target, serialized_node)
  1490. fx_node = self.graph.create_node(
  1491. "call_function", target, args, kwargs, name
  1492. )
  1493. self.deserialize_outputs(serialized_node, fx_node)
  1494. else:
  1495. raise SerializeError(
  1496. f"Unsupported target type for node {serialized_node}: {type(target)}"
  1497. )
  1498. fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
  1499. if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta:
  1500. fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts
  1501. def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec:
  1502. if i.type == "user_input":
  1503. return ep.InputSpec(
  1504. kind=ep.InputKind.USER_INPUT,
  1505. arg=self.deserialize_argument_spec(i.user_input.arg),
  1506. target=None,
  1507. )
  1508. elif i.type == "parameter":
  1509. return ep.InputSpec(
  1510. kind=ep.InputKind.PARAMETER,
  1511. arg=ep.TensorArgument(name=i.parameter.arg.name),
  1512. target=i.parameter.parameter_name,
  1513. )
  1514. elif i.type == "buffer":
  1515. return ep.InputSpec(
  1516. kind=ep.InputKind.BUFFER,
  1517. arg=ep.TensorArgument(name=i.buffer.arg.name),
  1518. target=i.buffer.buffer_name,
  1519. persistent=i.buffer.persistent,
  1520. )
  1521. elif i.type == "tensor_constant":
  1522. return ep.InputSpec(
  1523. kind=ep.InputKind.CONSTANT_TENSOR,
  1524. arg=ep.TensorArgument(name=i.tensor_constant.arg.name),
  1525. target=i.tensor_constant.tensor_constant_name,
  1526. )
  1527. elif i.type == "custom_obj":
  1528. return ep.InputSpec(
  1529. kind=ep.InputKind.CUSTOM_OBJ,
  1530. arg=ep.CustomObjArgument(
  1531. name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn
  1532. ),
  1533. target=i.custom_obj.custom_obj_name,
  1534. )
  1535. elif i.type == "token":
  1536. return ep.InputSpec(
  1537. kind=ep.InputKind.TOKEN,
  1538. arg=ep.TokenArgument(name=i.token.arg.name),
  1539. target=None
  1540. )
  1541. elif i.type == "constant_input":
  1542. return ep.InputSpec(
  1543. kind=ep.InputKind.USER_INPUT,
  1544. arg=ep.ConstantArgument(
  1545. name=i.constant_input.name,
  1546. value=self.deserialize_constant_input(i.constant_input.value)
  1547. ),
  1548. target=None,
  1549. )
  1550. else:
  1551. raise AssertionError(f"Unknown input spec {i}")
  1552. def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
  1553. if o.type == "user_output":
  1554. return ep.OutputSpec(
  1555. kind=ep.OutputKind.USER_OUTPUT,
  1556. arg=self.deserialize_argument_spec(o.user_output.arg),
  1557. target=None,
  1558. )
  1559. elif o.type == "loss_output":
  1560. return ep.OutputSpec(
  1561. kind=ep.OutputKind.LOSS_OUTPUT,
  1562. arg=ep.TensorArgument(name=o.loss_output.arg.name),
  1563. target=None,
  1564. )
  1565. elif o.type == "buffer_mutation":
  1566. return ep.OutputSpec(
  1567. kind=ep.OutputKind.BUFFER_MUTATION,
  1568. arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
  1569. target=o.buffer_mutation.buffer_name,
  1570. )
  1571. elif o.type == "gradient_to_parameter":
  1572. return ep.OutputSpec(
  1573. kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
  1574. arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name),
  1575. target=o.gradient_to_parameter.parameter_name,
  1576. )
  1577. elif o.type == "gradient_to_user_input":
  1578. return ep.OutputSpec(
  1579. kind=ep.OutputKind.GRADIENT_TO_USER_INPUT,
  1580. arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name),
  1581. target=o.gradient_to_user_input.user_input_name,
  1582. )
  1583. elif o.type == "user_input_mutation":
  1584. return ep.OutputSpec(
  1585. kind=ep.OutputKind.USER_INPUT_MUTATION,
  1586. arg=ep.TensorArgument(name=o.user_input_mutation.arg.name),
  1587. target=o.user_input_mutation.user_input_name,
  1588. )
  1589. elif o.type == "token":
  1590. return ep.OutputSpec(
  1591. kind=ep.OutputKind.TOKEN,
  1592. arg=ep.TokenArgument(name=o.token.arg.name),
  1593. target=None
  1594. )
  1595. else:
  1596. raise AssertionError(f"Unknown output spec {o}")
  1597. def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature:
  1598. return ep.ExportGraphSignature(
  1599. input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs],
  1600. output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs],
  1601. )
  1602. def deserialize(
  1603. self,
  1604. serialized_graph_module: GraphModule,
  1605. serialized_state_dict: Union[Dict[str, torch.Tensor], bytes],
  1606. constants: Union[Dict[str, Any], bytes],
  1607. example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
  1608. symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
  1609. ) -> Result:
  1610. global _CURRENT_DESERIALIZER
  1611. assert _CURRENT_DESERIALIZER is None
  1612. _CURRENT_DESERIALIZER = self
  1613. try:
  1614. self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
  1615. self.fake_tensor_mode = FakeTensorMode(
  1616. allow_fallback_kernels=False,
  1617. allow_non_fake_inputs=True,
  1618. shape_env=self.shape_env,
  1619. )
  1620. self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
  1621. self.constants = deserialize_torch_artifact(constants)
  1622. self.signature = self.deserialize_signature(serialized_graph_module.signature)
  1623. # deserialization does analysis with checks on 0/1, so we create fake range constraints and
  1624. # restore the original range constraints afterwards
  1625. self.symbol_name_to_range = {}
  1626. if symbol_name_to_range:
  1627. for k, vr in symbol_name_to_range.items():
  1628. lower = int(vr.lower)
1629. if vr.upper >= 2:  # a max >= 2 means this is not a sym bool range
  1630. lower = max(2, lower)
  1631. self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
  1632. if example_inputs is not None and len(example_inputs) > 0:
  1633. self.example_inputs = deserialize_torch_artifact(example_inputs)
  1634. else:
  1635. self.example_inputs = None
  1636. self.deserialize_graph(serialized_graph_module.graph)
  1637. module_call_graph = self.deserialize_module_call_graph(
  1638. serialized_graph_module.module_call_graph
  1639. )
  1640. return GraphModuleDeserializer.Result(
  1641. graph_module=ep._create_graph_module_for_export(
  1642. self.module, self.graph
  1643. ),
  1644. signature=self.signature,
  1645. module_call_graph=module_call_graph,
  1646. names_to_symbols=self.symbol_name_to_symbol,
  1647. state_dict=deserialize_torch_artifact(serialized_state_dict),
  1648. constants=self.constants,
  1649. example_inputs=self.example_inputs,
  1650. )
  1651. finally:
  1652. _CURRENT_DESERIALIZER = None
  1653. def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
  1654. if name in self.serialized_name_to_node:
  1655. raise SerializeError(f"Node {name} has already been deserialized before.")
  1656. self.serialized_name_to_node[name] = fx_node
  1657. assert "val" not in fx_node.meta
  1658. fx_node.meta["val"] = self.serialized_name_to_meta[name]
  1659. def deserialize_sym_op_inputs(self, inputs):
  1660. return tuple(self.deserialize_input(input.arg) for input in inputs)
  1661. def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node):
  1662. schema_args = target._schema.arguments
  1663. actual_args = {
  1664. input.name: self.deserialize_input(input.arg)
  1665. for input in serialized_node.inputs
  1666. }
  1667. args = []
  1668. kwargs = {}
  1669. for schema_arg in schema_args:
  1670. is_positional = (
  1671. not schema_arg.has_default_value() and not schema_arg.kwarg_only
  1672. )
  1673. if is_positional:
  1674. args.append(actual_args[schema_arg.name])
  1675. else:
  1676. if schema_arg.name in actual_args:
  1677. kwargs[schema_arg.name] = actual_args[schema_arg.name]
  1678. return tuple(args), kwargs
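# Illustrative example of the split above (schema abridged from memory):
# for aten.add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1), "self"
# and "other" have no default and are not kwarg-only, so they are placed in
# positional args, while "alpha" is emitted as a kwarg only if it appears in
# the serialized inputs.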
  1679. def deserialize_hoo_inputs(self, inputs: List[NamedArgument]):
  1680. """
  1681. For deserializing HOO inputs since HOOs do not have a schema.
  1682. """
  1683. args = []
  1684. kwargs = {}
  1685. for input_ in inputs:
  1686. if input_.name != "":
  1687. kwargs[input_.name] = self.deserialize_input(input_.arg)
  1688. else:
  1689. args.append(self.deserialize_input(input_.arg))
  1690. return (tuple(args), kwargs)
  1691. def deserialize_input(self, inp: Argument) -> Any:
  1692. value = inp.value
  1693. typ_ = inp.type
  1694. if typ_ == "as_none":
1695. # None should be deserialized as None, but it is encoded as a bool in the
1696. # serialized form, so convert the serialized object back to its torch equivalent.
  1697. return None
  1698. elif typ_ == "as_tensor":
  1699. return self.serialized_name_to_node[inp.as_tensor.name]
  1700. elif typ_ == "as_scalar_type":
  1701. return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]
  1702. elif typ_ == "as_memory_format":
  1703. return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]
  1704. elif typ_ == "as_layout":
  1705. return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]
  1706. elif typ_ == "as_graph":
  1707. assert isinstance(value, GraphArgument)
  1708. with self.save_graph_module():
  1709. self.deserialize_graph(value.graph)
  1710. submodule = ep._create_graph_module_for_export(self.module, self.graph)
  1711. self.module.register_module(value.name, submodule)
  1712. return self.graph.create_node(
  1713. "get_attr",
  1714. value.name,
  1715. name=value.name,
  1716. )
  1717. elif typ_ == "as_device":
  1718. return deserialize_device(inp.as_device)
  1719. elif typ_ == "as_int":
  1720. return inp.as_int
  1721. elif typ_ == "as_float":
  1722. return inp.as_float
  1723. elif typ_ == "as_bool":
  1724. return inp.as_bool
  1725. elif typ_ == "as_string":
  1726. return inp.as_string
  1727. elif typ_ == "as_sym_int":
  1728. return self.deserialize_sym_argument(inp.as_sym_int)
  1729. elif typ_ == "as_sym_bool":
  1730. return self.deserialize_sym_argument(inp.as_sym_bool)
  1731. elif isinstance(value, list):
  1732. if len(value) == 0:
  1733. return []
  1734. elif typ_ == "as_tensors":
  1735. result = []
  1736. for arg in value:
  1737. result.append(self.serialized_name_to_node[arg.name])
  1738. return result
  1739. elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
  1740. # convert from serialized.python.types.List to python list
  1741. return list(value)
  1742. elif typ_ in ("as_sym_ints", "as_sym_bools"):
  1743. return [self.deserialize_sym_argument(arg) for arg in value]
  1744. elif typ_ == "as_optional_tensors":
  1745. def deserialize_optional_tensor_args(a):
  1746. if a.type == "as_none":
  1747. return None
  1748. elif a.type == "as_tensor":
  1749. return self.serialized_name_to_node[a.value.name]
  1750. else:
  1751. raise SerializeError(f"Unhandled argument {inp}")
  1752. return list(map(deserialize_optional_tensor_args, value))
  1753. else:
  1754. raise SerializeError(f"Unhandled argument {inp}")
  1755. elif typ_ == "as_custom_obj":
  1756. if inp.as_custom_obj.name in self.serialized_name_to_node:
  1757. # Custom object has been lifted as an input
  1758. return self.serialized_name_to_node[inp.as_custom_obj.name]
  1759. return self.constants[inp.as_custom_obj.name]
  1760. elif typ_ == "as_operator":
  1761. return self.deserialize_operator(inp.as_operator)
  1762. else:
  1763. raise SerializeError(f"Unhandled argument {inp}")
  1764. def deserialize_constant_input(self, inp: ConstantValue) -> Any:
  1765. if inp.type == "as_int":
  1766. return int(inp.as_int)
  1767. elif inp.type == "as_float":
  1768. return float(inp.as_float)
  1769. elif inp.type == "as_string":
  1770. return str(inp.as_string)
  1771. elif inp.type == "as_bool":
  1772. return bool(inp.as_bool)
  1773. elif inp.type == "as_none":
  1774. return None
  1775. else:
  1776. raise SerializeError(f"Unhandled constant argument {inp} to deserialize")
  1777. def deserialize_sym_argument(self, sym_arg):
  1778. if isinstance(sym_arg, SymIntArgument):
  1779. if sym_arg.type == "as_int":
  1780. return sym_arg.as_int
  1781. elif sym_arg.type == "as_name":
  1782. return self.serialized_name_to_node[sym_arg.as_name]
  1783. elif isinstance(sym_arg, SymBoolArgument):
  1784. if sym_arg.type == "as_bool":
  1785. return sym_arg.as_bool
  1786. elif sym_arg.type == "as_name":
  1787. return self.serialized_name_to_node[sym_arg.as_name]
  1788. raise SerializeError(f"Unknown symbolic argument type: {sym_arg}")
  1789. def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
  1790. self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
  1791. def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node):
  1792. # Check single value return
  1793. if len(serialized_node.outputs) == 0:
  1794. return
  1795. if (
  1796. len(serialized_node.outputs) == 1
  1797. and serialized_node.outputs[0].type == "as_tensor"
  1798. ):
  1799. self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
  1800. return
  1801. elif len(serialized_node.outputs) == 1 and isinstance(
  1802. serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)
  1803. ):
  1804. self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
  1805. return
  1806. self.deserialize_multiple_outputs(serialized_node, fx_node)
  1807. def deserialize_multiple_outputs(
  1808. self, serialized_node: Node, fx_node: torch.fx.Node
  1809. ) -> None:
  1810. deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
  1811. def generate_getitem(
  1812. meta_val,
  1813. fx_node: torch.fx.Node,
  1814. arg: Union[TensorArgument, SymIntArgument],
  1815. idx: int,
  1816. ):
  1817. if isinstance(arg, TensorArgument):
  1818. name = arg.name
  1819. elif isinstance(arg, SymIntArgument):
  1820. name = arg.as_name
  1821. else:
  1822. raise AssertionError(
  1823. f"generate_getitem got unknown argument type {type(arg)}"
  1824. )
  1825. individual_output = self.graph.create_node(
  1826. "call_function",
  1827. operator.getitem,
  1828. (fx_node, idx),
  1829. name=name,
  1830. )
  1831. self.sync_fx_node(name, individual_output)
  1832. meta_val.append(self.serialized_name_to_meta[name])
  1833. # The derived `getitem` nodes should have the same stacktrace as the
  1834. # original `fx_node`
  1835. individual_output.meta.update(deserialized_metadata)
  1836. def generate_getitems(meta_val, fx_node: torch.fx.Node, args):
  1837. for idx, arg in enumerate(args):
  1838. if isinstance(arg, Argument):
  1839. arg = arg.value
  1840. if isinstance(arg, (TensorArgument, SymIntArgument)):
  1841. generate_getitem(meta_val, fx_node, arg, idx)
  1842. elif isinstance(arg, (list, tuple)):
  1843. list_output = self.graph.create_node(
  1844. "call_function",
  1845. operator.getitem,
  1846. (fx_node, idx),
  1847. )
  1848. meta_val.append([])
  1849. generate_getitems(meta_val[-1], list_output, arg)
  1850. list_output.meta.update(deserialized_metadata)
  1851. list_output.meta["val"] = meta_val[-1]
  1852. else:
  1853. raise NotImplementedError(f"Unimplemented node output type: {arg}")
  1854. # Convert multiple return types to FX format.
  1855. # In FX, each node only returns one value. So in order to represent
  1856. # multiple return values, we have to emit a `getitem` node for each
  1857. # return value.
  1858. # This performs the inverse mapping of the `serialize_outputs` call in
  1859. # serialization, see [NOTE: Multiple outputs]
  1860. meta_val: List[Any] = []
  1861. if len(serialized_node.outputs) == 1:
  1862. assert isinstance(serialized_node.outputs[0].value, list)
  1863. assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
  1864. generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors)
  1865. else:
  1866. generate_getitems(meta_val, fx_node, serialized_node.outputs)
  1867. # also update the metaval for `fx_node` to be a list(meta)
  1868. fx_node.meta["val"] = tuple(meta_val)
  1869. self.serialized_name_to_node[fx_node.name] = fx_node
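# Illustrative note: for a node serialized with multiple outputs (or a single
# as_tensors list), the helpers above restore FX's one-value-per-node
# convention by emitting getitem(fx_node, 0), getitem(fx_node, 1), ... and
# collecting their metas back into fx_node.meta["val"] as a tuple.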
  1870. def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
  1871. ret: Dict[str, Any] = {}
  1872. if stack_trace := metadata.get("stack_trace"):
  1873. ret["stack_trace"] = stack_trace
  1874. def deserialize_meta_func(serialized_target: str):
  1875. module = None
  1876. if serialized_target.startswith("torch.nn"):
  1877. module = torch.nn
  1878. serialized_target_names = serialized_target.split(".")[2:]
  1879. elif serialized_target.startswith("torch"):
  1880. module = torch
  1881. serialized_target_names = serialized_target.split(".")[1:]
  1882. else:
  1883. return self.deserialize_operator(serialized_target)
  1884. target = module
  1885. for name in serialized_target_names:
  1886. if not hasattr(target, name):
  1887. return serialized_target
  1888. else:
  1889. target = getattr(target, name)
  1890. return target
  1891. if nn_module_stack_str := metadata.get("nn_module_stack"):
  1892. # Originally serialized to "key,orig_path,type_str"
  1893. def import_nn_module_stack(key, path, ty):
  1894. return key, (path, ty)
  1895. # Helper function that splits strings by commas except for those
  1896. # encapsulated by parens, which are valid traces.
  1897. # TODO: Currently this is needed due to indexing Sequential
  1898. # layers introducing names in the form "layer.slice(1, None, None)".
  1899. # If that naming is improved, this fancier splitting can probably be
  1900. # reverted to a simple split by comma.
  1901. def metadata_split(metadata):
  1902. # Remove the parentheses and commas inside them
  1903. metadata = re.sub(r'\(.*?\)', '', metadata)
  1904. # Split the string by comma, except for those inside parentheses
  1905. return re.split(r'(?<!\()\s*,\s*(?!\()', metadata)
  1906. nn_module_stack = dict(
  1907. import_nn_module_stack(*metadata_split(item))
  1908. for item in nn_module_stack_str.split(ST_DELIMITER)
  1909. )
  1910. ret["nn_module_stack"] = nn_module_stack
  1911. if source_fn_st_str := metadata.get("source_fn_stack"):
1912. # Originally serialized to "fx_node_name,op_str"
  1913. source_fn_st = []
  1914. for source_fn_str in source_fn_st_str.split(ST_DELIMITER):
  1915. name, target_str = source_fn_str.split(",")
  1916. source_fn_st.append((name, deserialize_meta_func(target_str)))
  1917. ret["source_fn_stack"] = source_fn_st
  1918. if torch_fn_str := metadata.get("torch_fn"):
  1919. ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
  1920. return ret
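# Illustrative example of the metadata formats handled above (hypothetical
# values): an nn_module_stack entry is serialized as "key,path,type" items
# joined by ST_DELIMITER, e.g. "l1,mod.linear,torch.nn.Linear", and is rebuilt
# here into {"l1": ("mod.linear", "torch.nn.Linear")}; source_fn_stack items
# are "fx_node_name,op_str" pairs resolved through deserialize_meta_func.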
  1921. def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:
  1922. if x.type == "as_tensor":
  1923. return ep.TensorArgument(name=x.as_tensor.name)
  1924. elif x.type == "as_sym_int":
  1925. return ep.SymIntArgument(name=x.as_sym_int.as_name)
  1926. elif x.type == "as_custom_obj":
  1927. return ep.ConstantArgument(name=x.as_custom_obj.name, value=self.deserialize_input(x))
  1928. else:
  1929. return ep.ConstantArgument(name="", value=self.deserialize_input(x))
  1930. def deserialize_module_call_signature(
  1931. self, module_call_signature: ModuleCallSignature
  1932. ) -> ep.ModuleCallSignature:
  1933. return ep.ModuleCallSignature(
  1934. inputs=[
  1935. self.deserialize_argument_spec(x) for x in module_call_signature.inputs
  1936. ],
  1937. outputs=[
  1938. self.deserialize_argument_spec(x) for x in module_call_signature.outputs
  1939. ],
  1940. in_spec=treespec_loads(module_call_signature.in_spec),
  1941. out_spec=treespec_loads(module_call_signature.out_spec),
  1942. )
  1943. def deserialize_module_call_graph(
  1944. self, module_call_graph: List[ModuleCallEntry]
  1945. ) -> List[ep.ModuleCallEntry]:
  1946. return [
  1947. ep.ModuleCallEntry(
  1948. fqn=entry.fqn,
  1949. signature=(
  1950. self.deserialize_module_call_signature(entry.signature)
  1951. if entry.signature
  1952. else None
  1953. ),
  1954. )
  1955. for entry in module_call_graph
  1956. ]
  1957. @final
  1958. class ExportedProgramDeserializer(metaclass=Final):
  1959. def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
  1960. self.expected_opset_version: Dict[str, int] = {}
  1961. if expected_opset_version:
  1962. self.expected_opset_version.update(expected_opset_version)
  1963. if "aten" not in self.expected_opset_version:
  1964. self.expected_opset_version["aten"] = torch._C._get_max_operator_version()
  1965. def deserialize_range_constraints(
  1966. self,
  1967. symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges],
  1968. symbol_name_to_symbol: Dict[str, sympy.Symbol],
  1969. ) -> Dict[sympy.Symbol, ValueRanges]:
  1970. range_constraints = {}
  1971. for k, v in symbol_name_to_range.items():
  1972. if symbol := symbol_name_to_symbol.get(k):
  1973. range_constraints[symbol] = v # type: ignore[arg-type]
  1974. else:
  1975. log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004
  1976. return range_constraints
  1977. def deserialize(
  1978. self,
  1979. exported_program: ExportedProgram,
  1980. state_dict: Union[Dict[str, torch.Tensor], bytes],
  1981. constants: Union[Dict[str, torch.Tensor], bytes],
  1982. example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
  1983. ) -> ep.ExportedProgram:
  1984. assert isinstance(exported_program, ExportedProgram)
  1985. version = exported_program.schema_version
  1986. # TODO(zhxchen17) blocked on thrift schema refactor
  1987. if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0):
  1988. raise SerializeError(
  1989. f"Serialized schema version {exported_program.schema_version} "
  1990. f"does not match our current schema version {SCHEMA_VERSION}."
  1991. )
  1992. symbol_name_to_range = {
  1993. k: symbolic_shapes.ValueRanges(
  1994. _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
  1995. )
  1996. for k, v in exported_program.range_constraints.items()
  1997. }
  1998. res = (
  1999. GraphModuleDeserializer()
  2000. .deserialize(
  2001. exported_program.graph_module,
  2002. state_dict,
  2003. constants,
  2004. example_inputs,
  2005. symbol_name_to_range,
  2006. )
  2007. )
  2008. range_constraints = self.deserialize_range_constraints(
  2009. symbol_name_to_range,
  2010. res.names_to_symbols,
  2011. )
  2012. model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version
  2013. return ep.ExportedProgram(
  2014. root=res.graph_module,
  2015. graph=res.graph_module.graph,
  2016. graph_signature=res.signature,
  2017. state_dict=res.state_dict, # type: ignore[arg-type]
  2018. range_constraints=range_constraints,
  2019. module_call_graph=res.module_call_graph,
  2020. example_inputs=res.example_inputs,
  2021. verifier=load_verifier(exported_program.dialect),
  2022. constants=res.constants,
  2023. )
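# A minimal usage sketch (illustrative only; normally reached through the
# module-level deserialize() helper further below, and the byte-blob variable
# names here are hypothetical):
#
#   deserializer = ExportedProgramDeserializer()
#   ep_out = deserializer.deserialize(
#       serialized_ep, state_dict_bytes, constants_bytes, example_inputs_bytes
#   )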
  2024. class EnumEncoder(json.JSONEncoder):
  2025. def default(self, obj):
  2026. if isinstance(obj, Enum):
  2027. return obj.value
  2028. if isinstance(obj, bytes):
  2029. return base64.b64encode(obj).decode("utf-8")
  2030. return super().default(obj)
  2031. def _dataclass_to_dict(obj):
  2032. if isinstance(obj, _Union):
  2033. return {obj.type: _dataclass_to_dict(obj.value)}
  2034. elif dataclasses.is_dataclass(obj):
  2035. return {
  2036. f.name: _dataclass_to_dict(getattr(obj, f.name))
  2037. for f in dataclasses.fields(obj)
  2038. if not (f.default is None and getattr(obj, f.name) is None)
  2039. }
  2040. elif isinstance(obj, list):
  2041. return [_dataclass_to_dict(x) for x in obj]
  2042. elif isinstance(obj, tuple):
  2043. return tuple(_dataclass_to_dict(x) for x in obj)
  2044. elif isinstance(obj, dict):
  2045. return {k: _dataclass_to_dict(v) for k, v in obj.items()}
  2046. else:
  2047. return obj
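# Illustrative example (hypothetical values): a _Union value built as
# SymInt.create(as_int=3) serializes to {"as_int": 3}, and dataclass fields
# whose default and current value are both None are dropped, keeping the
# resulting JSON compact.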
  2048. def serialize(
  2049. exported_program: ep.ExportedProgram,
  2050. opset_version: Optional[Dict[str, int]] = None,
  2051. ) -> SerializedArtifact:
  2052. serialized_program = ExportedProgramSerializer(opset_version).serialize(
  2053. exported_program
  2054. )
  2055. assert isinstance(serialized_program.exported_program, ExportedProgram)
  2056. json_program = json.dumps(
  2057. _dataclass_to_dict(serialized_program.exported_program), cls=EnumEncoder
  2058. )
  2059. json_bytes = json_program.encode("utf-8")
  2060. artifact = SerializedArtifact(
  2061. json_bytes,
  2062. serialized_program.state_dict,
  2063. serialized_program.constants,
  2064. serialized_program.example_inputs
  2065. )
  2066. return artifact
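# Illustrative note: the exported program itself is stored as UTF-8 encoded
# JSON (enums and bytes handled by EnumEncoder), while state_dict, constants,
# and example_inputs stay as the opaque byte blobs produced during
# serialization by serialize_torch_artifact.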
  2067. def _dict_to_dataclass(cls, data):
  2068. assert not isinstance(cls, str), f"Unresolved class type: '{cls}'."
  2069. if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls):
  2070. if data is None:
  2071. return None
  2072. ty_args = typing.get_args(cls)
  2073. assert len(ty_args) == 2
  2074. return _dict_to_dataclass(ty_args[0], data)
  2075. elif isinstance(cls, type) and issubclass(cls, _Union):
  2076. assert isinstance(data, dict)
  2077. assert len(data) == 1
  2078. _type = next(iter(data.keys()))
  2079. _value = next(iter(data.values()))
  2080. assert isinstance(_type, str)
  2081. field_type = cls.__annotations__[_type]
  2082. return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
  2083. elif dataclasses.is_dataclass(cls):
  2084. obj = cls(**data) # type: ignore[assignment]
  2085. type_hints = typing.get_type_hints(cls)
  2086. for f in dataclasses.fields(cls):
  2087. name = f.name
  2088. new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name))
  2089. setattr(obj, name, new_field_obj)
  2090. return obj
  2091. elif isinstance(data, list):
  2092. if len(data) == 0:
  2093. return data
  2094. d_type = typing.get_args(cls)[0]
  2095. return [_dict_to_dataclass(d_type, d) for d in data]
  2096. elif isinstance(data, dict):
  2097. v_type = typing.get_args(cls)[1]
  2098. return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()}
  2099. return data
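# Illustrative example (hypothetical field values): given an annotation of
# Optional[List[TensorMeta]] and data like [{"dtype": ..., "sizes": [...]}],
# _dict_to_dataclass peels the Optional, recurses with the list element type,
# and rebuilds each dict into a TensorMeta dataclass instance.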
  2100. def deserialize(
  2101. artifact: SerializedArtifact,
  2102. expected_opset_version: Optional[Dict[str, int]] = None,
  2103. ) -> ep.ExportedProgram:
  2104. assert isinstance(artifact.exported_program, bytes)
  2105. exported_program_str = artifact.exported_program.decode("utf-8")
  2106. exported_program_dict = json.loads(exported_program_str)
  2107. serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
  2108. return (
  2109. ExportedProgramDeserializer(expected_opset_version)
  2110. .deserialize(
  2111. serialized_exported_program,
  2112. artifact.state_dict,
  2113. artifact.constants,
  2114. artifact.example_inputs,
  2115. )
  2116. )
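# A minimal round-trip sketch, assuming `ep_in` is an ep.ExportedProgram
# produced by torch.export (illustrative only):
#
#   artifact = serialize(ep_in)      # SerializedArtifact of byte blobs
#   ep_out = deserialize(artifact)   # reconstructed ep.ExportedProgram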
  2117. def _canonicalize_graph(
  2118. sorted_inputs, sorted_outputs, graph
  2119. ) -> Tuple[Graph, Dict[str, str]]:
  2120. def _get_argument(a: Argument):
  2121. if a.type == "as_none":
  2122. return None
  2123. elif a.type == "as_tensor":
  2124. return a.as_tensor
  2125. elif a.type == "as_tensors":
  2126. return a.as_tensors
  2127. elif a.type == "as_int":
  2128. return None
  2129. elif a.type == "as_ints":
  2130. return None
  2131. elif a.type == "as_float":
  2132. return None
  2133. elif a.type == "as_floats":
  2134. return None
  2135. elif a.type == "as_string":
  2136. return None
  2137. elif a.type == "as_strings":
  2138. return None
  2139. elif a.type == "as_sym_int":
  2140. return a.as_sym_int
  2141. elif a.type == "as_sym_ints":
  2142. return a.as_sym_ints
  2143. elif a.type == "as_scalar_type":
  2144. return None
  2145. elif a.type == "as_memory_format":
  2146. return None
  2147. elif a.type == "as_layout":
  2148. return None
  2149. elif a.type == "as_device":
  2150. return None
  2151. elif a.type == "as_bool":
  2152. return None
  2153. elif a.type == "as_bools":
  2154. return None
  2155. elif a.type == "as_sym_bool":
  2156. return a.as_sym_bool
  2157. elif a.type == "as_sym_bools":
  2158. return a.as_sym_bools
  2159. elif a.type == "as_graph":
  2160. return None
  2161. elif a.type == "as_optional_tensors":
  2162. return a.as_optional_tensors
  2163. elif a.type == "as_custom_obj":
  2164. return None
  2165. elif a.type == "as_operator":
  2166. return None
  2167. else:
  2168. raise AssertionError(f"Unknown input type to the ExportedProgram: {a}")
  2169. # Stage 1: Reorder named items.
  2170. def for_args(f, a):
  2171. assert isinstance(a, Argument)
  2172. pytree.tree_map(f, _get_argument(a))
  2173. def sort_nodes(nodes):
  2174. @dataclass
  2175. class Edges:
  2176. outs: List[int]
  2177. ins: int
  2178. graph_inputs: Set[str] = set()
  2179. def_table: Dict[str, int] = {}
  2180. edges: Dict[int, Edges] = {}
  2181. candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = []
  2182. rank: Dict[str, int] = {}
  2183. ret: List[Node] = []
  2184. def get_name(a) -> Optional[str]:
  2185. if a is None:
  2186. return None
  2187. if isinstance(a, TensorArgument):
  2188. return a.name
  2189. elif isinstance(a, (SymIntArgument, SymBoolArgument)):
  2190. if a.type == "as_name":
  2191. return a.as_name
  2192. elif a.type in ("as_int", "as_bool"):
  2193. return None
  2194. else:
  2195. raise AssertionError(f"Unknown argument type: {a}")
  2196. elif isinstance(a, OptionalTensorArgument):
  2197. if a.type == "as_tensor":
  2198. return a.as_tensor.name
  2199. elif a.type == "as_none":
  2200. return None
  2201. else:
  2202. raise AssertionError(f"Unknown optional tensor type: {a}")
  2203. else:
  2204. raise AssertionError(f"Unknown argument type: {a}")
  2205. for i in sorted_inputs:
  2206. def add_input(a):
  2207. if s := get_name(a):
  2208. graph_inputs.add(s)
  2209. for_args(add_input, i)
  2210. for idx, node in enumerate(nodes):
  2211. def add_def(a):
  2212. if s := get_name(a):
  2213. assert s not in def_table
  2214. def_table[s] = idx
  2215. for o in node.outputs:
  2216. for_args(add_def, o)
  2217. edges[idx] = Edges([], 0)
  2218. for idx, user in enumerate(nodes):
  2219. def add_edge(a):
  2220. if s := get_name(a):
  2221. if s not in def_table:
  2222. assert s in graph_inputs
  2223. return
  2224. src = def_table[s]
  2225. edges[src].outs.append(idx)
  2226. edges[idx].ins += 1
  2227. for i in user.inputs:
  2228. for_args(add_edge, i.arg)
  2229. def add_rank(a):
  2230. if s := get_name(a):
  2231. assert s not in rank
  2232. rank[s] = len(rank)
  2233. def get_rank(a):
  2234. if s := get_name(a):
  2235. return rank[s]
  2236. else:
  2237. return -1
  2238. for i in sorted_inputs:
  2239. for_args(add_rank, i)
  2240. def add_candidate(idx: int):
  2241. def get_ranks(i):
  2242. ranks = []
  2243. for_args(lambda x: ranks.append(get_rank(x)), i)
  2244. return ranks
  2245. node = nodes[idx]
  2246. args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs]
  2247. heapq.heappush(candidates, (node.target, args_rank, idx))
  2248. for idx, e in edges.items():
  2249. if e.ins == 0:
  2250. add_candidate(idx)
  2251. while len(candidates) > 0:
  2252. _, _, idx = heapq.heappop(candidates)
  2253. node = nodes[idx]
  2254. for o in node.outputs:
  2255. for_args(add_rank, o)
  2256. ret.append(node)
  2257. assert idx in edges
  2258. for user in edges[idx].outs:
  2259. e = edges[user]
  2260. assert e.ins > 0
  2261. e.ins -= 1
  2262. if e.ins == 0:
  2263. add_candidate(user)
  2264. edges[idx].outs.clear()
  2265. return ret
  2266. sorted_nodes = sort_nodes(graph.nodes)
  2267. assert len(sorted_nodes) == len(graph.nodes)
  2268. # Stage 2: Rename nodes.
  2269. name_table: Dict[str, str] = {}
  2270. def rename_def(a):
  2271. def _rename(arg_name, values):
  2272. new_name = f"_{len(name_table)}"
  2273. assert arg_name not in name_table
  2274. name_table[arg_name] = new_name
  2275. assert arg_name in values
  2276. values[new_name] = values.pop(arg_name)
  2277. return new_name
  2278. if a is None:
  2279. return
  2280. if isinstance(a, TensorArgument):
  2281. a.name = _rename(a.name, graph.tensor_values)
  2282. elif isinstance(a, SymIntArgument):
  2283. if a.type == "as_name":
  2284. a.as_name = _rename(a.as_name, graph.sym_int_values)
  2285. elif isinstance(a, SymBoolArgument):
  2286. if a.type == "as_name":
  2287. a.as_name = _rename(a.as_name, graph.sym_bool_values)
  2288. else:
  2289. raise AssertionError(f"Unknown argument type: {a}")
  2290. def replace_use(a):
  2291. if a is None:
  2292. return
  2293. if isinstance(a, TensorArgument):
  2294. a.name = name_table.get(a.name, a.name)
  2295. elif isinstance(a, SymIntArgument):
  2296. if a.type == "as_name":
  2297. a.as_name = name_table.get(a.as_name, a.as_name)
  2298. elif isinstance(a, SymBoolArgument):
  2299. if a.type == "as_name":
  2300. a.as_name = name_table.get(a.as_name, a.as_name)
  2301. elif isinstance(a, OptionalTensorArgument):
  2302. if a.type == "as_tensor":
  2303. a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name)
  2304. else:
  2305. raise AssertionError(f"Unknown argument type: {a}")
  2306. for i in sorted_inputs:
  2307. for_args(rename_def, i)
  2308. for n in sorted_nodes:
  2309. for o in n.outputs:
  2310. for_args(rename_def, o)
  2311. for n in sorted_nodes:
  2312. for i in n.inputs:
  2313. for_args(replace_use, i.arg)
  2314. for o in sorted_outputs:
  2315. for_args(replace_use, o)
  2316. # Stage 3: Remove unstable fields.
  2317. for n in sorted_nodes:
  2318. n.metadata.clear()
  2319. # Stage 4: Aggregate values.
  2320. sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=operator.itemgetter(0)))
  2321. sorted_sym_int_values = dict(
  2322. sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
  2323. )
  2324. sorted_sym_bool_values = dict(
  2325. sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
  2326. )
  2327. # Stage 5: Recurse in subgraphs.
  2328. counter = 0
  2329. for node in sorted_nodes:
  2330. for i in node.inputs:
  2331. a = i.arg
  2332. if a.type == "as_graph":
  2333. a.as_graph.graph = _canonicalize_graph(
  2334. a.as_graph.graph.inputs, a.as_graph.graph.outputs, a.as_graph.graph
  2335. )
  2336. a.as_graph.name = f"_g{counter}"
  2337. counter += 1
  2338. graph = Graph(
  2339. inputs=sorted_inputs,
  2340. outputs=sorted_outputs,
  2341. nodes=sorted_nodes,
  2342. tensor_values=sorted_tensor_values,
  2343. sym_int_values=sorted_sym_int_values,
  2344. sym_bool_values=sorted_sym_bool_values,
  2345. is_single_tensor_return=graph.is_single_tensor_return,
  2346. )
  2347. return graph, name_table
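# Illustrative note: _canonicalize_graph returns both the rewritten Graph and
# name_table, which maps original value names to their stable canonical names
# (e.g. some placeholder name -> "_0"); canonicalize() below uses this table
# to patch the signature specs so they keep referring to the renamed values.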
  2348. def canonicalize(ep: ExportedProgram) -> ExportedProgram:
  2349. """
2350. Normalize a serialized ExportedProgram, so that different eager programs which
2351. share the same semantics get a single representation on disk.
2352. This function canonicalizes an ExportedProgram by:
2353. 1. Sorting nodes in topological order.
2354. 2. Renaming nodes to have unique names.
2355. 3. Removing unstable fields.
2356. 4. Aggregating the above program fields.
2357. 5. Recursing into subgraphs.
  2358. Args:
  2359. ep (ExportedProgram): The ExportedProgram to canonicalize.
  2360. Returns:
  2361. ExportedProgram: The canonicalized exported program.
  2362. """
  2363. ep = copy.deepcopy(ep)
  2364. opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))
  2365. range_constraints = dict(sorted(ep.range_constraints.items(), key=operator.itemgetter(0)))
  2366. module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
  2367. signature = ep.graph_module.signature
  2368. graph = ep.graph_module.graph
  2369. assert len(graph.inputs) == len(signature.input_specs)
  2370. assert len(graph.outputs) == len(signature.output_specs)
  2371. def rank_input(inp) -> Tuple[int, Optional[str], int]:
  2372. idx, (arg, spec) = inp
  2373. assert isinstance(spec, InputSpec)
  2374. if spec.type == "user_input":
  2375. return 5, None, idx
  2376. elif spec.type == "parameter":
  2377. return 1, spec.parameter.parameter_name, idx
  2378. elif spec.type == "buffer":
  2379. return 2, spec.buffer.buffer_name, idx
  2380. elif spec.type == "tensor_constant":
  2381. return 3, spec.tensor_constant.tensor_constant_name, idx
  2382. elif spec.type == "custom_obj":
  2383. return 4, spec.custom_obj.custom_obj_name, idx
  2384. elif spec.type == "token":
  2385. return 0, None, idx
  2386. elif spec.type == "constant_input":
  2387. return 6, spec.constant_input.name, idx
  2388. else:
  2389. raise AssertionError(f"Unknown input type: {spec}")
  2390. def rank_output(out) -> Tuple[int, Optional[str], int]:
  2391. idx, (arg, spec) = out
  2392. assert isinstance(spec, OutputSpec)
  2393. if spec.type == "user_output":
  2394. return 3, None, idx
  2395. elif spec.type == "loss_output":
  2396. return 3, None, idx
  2397. elif spec.type == "buffer_mutation":
  2398. return 1, spec.buffer_mutation.buffer_name, idx
  2399. elif spec.type == "gradient_to_parameter":
  2400. return 4, spec.gradient_to_parameter.parameter_name, idx
  2401. elif spec.type == "gradient_to_user_input":
  2402. return 5, None, idx
  2403. elif spec.type == "user_input_mutation":
  2404. return 2, None, idx
  2405. elif spec.type == "token":
  2406. return 0, None, idx
  2407. else:
  2408. raise AssertionError(f"Unknown output type: {spec}")
  2409. sorted_ins = sorted(
  2410. enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input
  2411. )
  2412. sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment]
  2413. sorted_outs = sorted(
  2414. enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output
  2415. )
  2416. sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment]
  2417. sorted_graph, replace_table = _canonicalize_graph(
  2418. sorted_inputs, sorted_outputs, graph
  2419. )
  2420. def replace_input(inp):
  2421. assert isinstance(spec, InputSpec)
  2422. if spec.type == "user_input":
  2423. arg = spec.user_input.arg
  2424. if arg.type == "as_tensor":
  2425. t = arg.as_tensor
  2426. t.name = replace_table[t.name]
  2427. elif arg.type == "as_sym_int":
  2428. s = arg.as_sym_int
  2429. if s.type == "as_name":
  2430. s.as_name = replace_table[s.as_name]
  2431. elif s.type == "as_int":
  2432. pass
  2433. else:
  2434. raise AssertionError(f"Unknown sym_int type: {s}")
  2435. elif arg.type in (
  2436. "as_none",
  2437. "as_bool",
  2438. "as_int",
  2439. "as_float",
  2440. "as_string",
  2441. "as_custom_obj",
  2442. ):
  2443. return
  2444. else:
  2445. raise AssertionError(f"Unknown input type: {arg}")
  2446. elif spec.type == "parameter":
  2447. t = spec.parameter.arg
  2448. t.name = replace_table[t.name]
  2449. elif spec.type == "buffer":
  2450. t = spec.buffer.arg
  2451. t.name = replace_table[t.name]
  2452. elif spec.type == "tensor_constant":
  2453. t = spec.tensor_constant.arg
  2454. t.name = replace_table[t.name]
  2455. elif spec.type == "custom_obj":
  2456. return
  2457. elif spec.type == "token":
  2458. tok = spec.token.arg
  2459. tok.name = replace_table[tok.name]
  2460. elif spec.type == "constant_input":
  2461. return
  2462. else:
  2463. raise AssertionError(f"Unknown input type: {spec}")
  2464. def replace_output(out):
  2465. assert isinstance(spec, OutputSpec)
  2466. if spec.type == "user_output":
  2467. arg = spec.user_output.arg
  2468. if arg.type == "as_tensor":
  2469. t = arg.as_tensor
  2470. t.name = replace_table[t.name]
  2471. elif arg.type == "as_sym_int":
  2472. s = arg.as_sym_int
  2473. if s.type == "as_name":
  2474. s.as_name = replace_table[s.as_name]
  2475. elif s.type == "as_int":
  2476. pass
  2477. else:
  2478. raise AssertionError(f"Unknown sym_int type: {s}")
  2479. elif arg.type in ("as_none", "as_int", "as_float", "as_string"):
  2480. return
  2481. else:
  2482. raise AssertionError(f"Unknown input type: {arg}")
  2483. elif spec.type == "loss_output":
  2484. t = spec.loss_output.arg
  2485. t.name = replace_table[t.name]
  2486. elif spec.type == "buffer_mutation":
  2487. t = spec.buffer_mutation.arg
  2488. t.name = replace_table[t.name]
  2489. elif spec.type == "gradient_to_parameter":
  2490. t = spec.gradient_to_parameter.arg
  2491. t.name = replace_table[t.name]
  2492. elif spec.type == "gradient_to_user_input":
  2493. g = spec.gradient_to_user_input
  2494. g.arg.name = replace_table[g.arg.name]
  2495. g.user_input_name = replace_table[g.user_input_name]
  2496. elif spec.type == "user_input_mutation":
  2497. u = spec.user_input_mutation
  2498. u.arg.name = replace_table[u.arg.name]
  2499. u.user_input_name = replace_table[u.user_input_name]
  2500. elif spec.type == "token":
  2501. tok = spec.token.arg
  2502. tok.name = replace_table[tok.name]
  2503. else:
  2504. raise AssertionError(f"Unknown output type: {spec}")
  2505. for spec in input_specs:
  2506. replace_input(spec)
  2507. for spec in output_specs:
  2508. replace_output(spec)
  2509. return ExportedProgram(
  2510. graph_module=GraphModule(
  2511. graph=sorted_graph,
  2512. signature=GraphSignature(
  2513. input_specs=list(input_specs),
  2514. output_specs=list(output_specs),
  2515. ),
  2516. module_call_graph=module_call_graph,
  2517. ),
  2518. opset_version=opset_version,
  2519. range_constraints=range_constraints,
  2520. schema_version=ep.schema_version,
  2521. dialect=ep.dialect
  2522. )
  2523. class CustomOpHandler:
  2524. """
  2525. Base class for handling custom operators.
  2526. """
2527. @classmethod
2528. def namespace(cls):
2529. raise NotImplementedError(f"{cls.__name__} namespace() must be implemented")
2530. @classmethod
2531. def op_name(cls, op_type):
2532. raise NotImplementedError(f"{cls.__name__} op_name() must be implemented")
2533. @classmethod
2534. def op_type(cls, op_name):
2535. raise NotImplementedError(f"{cls.__name__} op_type() must be implemented")
2536. @classmethod
2537. def op_schema(cls, op_type):
2538. raise NotImplementedError(f"{cls.__name__} op_schema() must be implemented")
  2539. def register_custom_op_handler(
  2540. op_handler: CustomOpHandler,
  2541. op_type: Type[Any],
  2542. ):
  2543. """Register custom de/serialization method for a node."""
  2544. assert isinstance(op_handler, CustomOpHandler), f"Expected CustomOpHandler, got {type(op_handler)}."
  2545. _serialization_registry[op_type] = op_handler
2546. # FIXME: handle deserialization later.
  2547. _deserialization_registry[op_handler.namespace()] = op_handler
  2548. def allowed_registered_op_types():
  2549. return tuple(
  2550. _serialization_registry.keys()
  2551. )
  2552. # Registry to store all custom serialization implementations.
2553. # The registry maps an operation to its serialization handler (a CustomOpHandler), in its own
2554. # namespace to avoid conflicts.
  2555. # Serialization: Op type --> custom handler.
  2556. # De-serialization: Namespace --> custom handler.
  2557. _serialization_registry: Dict[Type[Any], CustomOpHandler] = {}
  2558. _deserialization_registry: Dict[str, CustomOpHandler] = {}
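# A minimal registration sketch (illustrative only; MyOpHandler and MyOpType
# are hypothetical names, not part of this module):
#
#   class MyOpHandler(CustomOpHandler):
#       @classmethod
#       def namespace(cls):
#           return "my_namespace"
#       # op_name/op_type/op_schema would be implemented similarly.
#
#   register_custom_op_handler(MyOpHandler(), MyOpType)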