# mypy: allow-untyped-defs
import operator
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.export._trace
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ConstantArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph
from torchgen.model import FunctionSchema

def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    """Replace the aten.sym_size.int -> aten.scalar_tensor -> aten.div.Scalar_mode("trunc")
    -> aten.Int.Tensor chain with an integer floor division on the symbolic size."""

    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
            scalar_tensor, scale, rounding_mode="trunc"
        )
        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
        return int_tensor

    def replacement(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        return sym_size_int // scale

    subgraph_rewriter.replace_pattern(gm, pattern, replacement)

def normalize_name(name: str) -> str:
    return name.replace(".", "_")


def ir_name_to_func_name(name: str) -> str:
    """prim::If -> convert_prim_If"""
    name_list = name.split("::")
    return "convert_" + "_".join(name_list)

# These operators are automatically populated as instance methods of
# TS2FXGraphConverter, named convert_<namespace>_<opname>().
# Please check __init__ for the method population implementation.
kind_to_standard_operators = {
    "prim::TupleIndex": operator.getitem,
    "aten::__is__": operator.is_,
    "aten::__isnot__": operator.is_not,
    "aten::__not__": operator.not_,
    "aten::__contains__": operator.contains,
}
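# For example, a node of kind "prim::TupleIndex" is dispatched to the generated
# method convert_prim_TupleIndex(), which forwards to
# _convert_standard_operators() and emits a call to operator.getitem.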

def get_op_overload(node: torch._C.Node):
    schema_str = node.schema()
    schema = FunctionSchema.parse(schema_str)
    ns, op_name = str(schema.name.name).split("::")
    override = schema.name.overload_name

    try:
        op_overload_mod = getattr(torch.ops, ns)
        op_overload_packet = getattr(op_overload_mod, op_name)
        if override:
            op_overload = getattr(op_overload_packet, override)
        else:
            op_overload = op_overload_packet.default
    except Exception as e:
        raise RuntimeError(
            f"Unable to find operator {node.kind()} with schema {node.schema()}"
        ) from e

    return op_overload
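
# Illustrative example (assumes the standard aten::add.Tensor overload, which
# exists in stock PyTorch builds): for a node whose schema() returns
#   "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
# get_op_overload resolves to torch.ops.aten.add.Tensor; a schema without an
# overload name resolves to the packet's .default overload.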

class TS2FXGraphConverter:
    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        param_names: Set[str],
        buffer_names: Set[str],
    ):
        self.ts_graph = ts_graph
        self.param_names = param_names
        self.buffer_names = buffer_names

        self.fx_graph: torch.fx.Graph = torch.fx.Graph()
        self.input_specs: List[InputSpec] = []
        self.output_specs: List[OutputSpec] = []

        self.name_to_node: Dict[
            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
        ] = {}
        self.constant_map: Dict[str, Any] = {}
        self.attribute_map: Dict[str, Any] = {}
        self.tensor_constants: Dict[str, torch.Tensor] = {}

        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

        # Populate methods for the standard operators.
        for k in kind_to_standard_operators.keys():
            handler_func_name = ir_name_to_func_name(k)
            # Create an indirect function call:
            # convert_<namespace>_<opname> --> lambda node: _convert_standard_operators(node)
            setattr(
                self,
                handler_func_name,
                lambda node: self._convert_standard_operators(node),
            )

    def add_subgraph(self, subgraph) -> str:
        name = f"subgraph_{len(self.subgraphs)}"
        self.subgraphs[name] = subgraph
        return name

    def get_args_kwargs(self, node: torch._C.Node, schema):
        args = []
        kwargs = {}
        for input, schema_arg in zip(node.inputs(), schema.arguments):
            if schema_arg.kwarg_only:
                kwargs[schema_arg.name] = self.get_fx_value(input)
            else:
                args.append(self.get_fx_value(input))

        return tuple(args), kwargs

    def get_fx_value(self, value: torch._C.Value):
        value_name = value.debugName()
        if value_name in self.name_to_node:
            input_node = self.name_to_node[value_name]
            return input_node
        elif value_name in self.attribute_map:
            attr_name = self.attribute_map[value_name]
            if attr_name in self.name_to_node:
                input_node = self.name_to_node[attr_name]
                return input_node
            else:
                raise ValueError(f"Value {attr_name} not found")
        elif value_name in self.constant_map:
            return self.constant_map[value_name]
        else:
            raise ValueError(f"Input {value_name} not found")

    def convert(self) -> torch.fx.GraphModule:
        self.convert_graph_inputs()

        for node in self.ts_graph.nodes():
            self.convert_node(node)

        self.convert_graph_outputs()

        gm = torch.fx.GraphModule(self.subgraphs, self.fx_graph)

        inplace_optimize_sym_size_div(gm)

        gm.graph.lint()
        return gm

    def convert_graph_inputs(self):
        for graph_input in self.ts_graph.inputs():
            name = graph_input.debugName()
            normalized_name = normalize_name(name)

            fx_node = self.fx_graph.placeholder(normalized_name)

            # fx_node.meta["val"] = FakeTensor()
            # TODO: set fx_node.meta["val"]

            self.name_to_node[name] = fx_node

            if name in self.param_names:
                self.input_specs.append(
                    InputSpec(
                        InputKind.PARAMETER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
            elif name in self.buffer_names:
                self.input_specs.append(
                    InputSpec(
                        InputKind.BUFFER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                        persistent=True,
                    )
                )
            else:
                self.input_specs.append(
                    InputSpec(
                        InputKind.USER_INPUT,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )

    def convert_prim_Constant(self, node: torch._C.Node):
        name = node.output().debugName()

        value: Any = None
        if node.hasAttribute("value"):
            constant_kind = node.kindOf("value")
            if constant_kind == "i":
                value = node.i("value")
            elif constant_kind == "f":
                value = node.f("value")
            elif constant_kind == "s":
                value = node.s("value")
            elif constant_kind == "t":
                # Lift the tensor constant as a placeholder input.
                placeholder_name = f"constant_{name}"
                fx_node = self.fx_graph.placeholder(placeholder_name)
                self.name_to_node[name] = fx_node
                self.tensor_constants[placeholder_name] = node.t("value")

                self.input_specs.append(
                    InputSpec(
                        InputKind.CONSTANT_TENSOR,
                        arg=TensorArgument(name=placeholder_name),
                        target=placeholder_name,
                    )
                )

                value = fx_node
            elif constant_kind == "ival":
                value = node.ival("value")
            else:
                raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
        else:
            value = None

        self.constant_map[name] = value

    def convert_prim_device(self, node: torch._C.Node):
        input_type = node.input().type()
        if input_type.isSubtypeOf(torch._C.TensorType.get()):
            device = input_type.device()  # type: ignore[attr-defined]
            output_name = node.output().debugName()
            self.constant_map[output_name] = device
        else:
            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")

    def convert_prim_dtype(self, node: torch._C.Node):
        dtype = node.input().type().dtype()
        output_name = node.output().debugName()
        self.constant_map[output_name] = dtype

    def convert_prim_GetAttr(self, node: torch._C.Node):
        def get_attr(name: str):
            if name in self.attribute_map:
                return self.attribute_map[name]
            else:
                raise ValueError(f"Attribute {name} not found")

        output_name = node.output().debugName()

        attr_name = node.s("name")
        input_name = node.input().debugName()

        root_attr_name = get_attr(input_name)
        self.attribute_map[output_name] = (
            f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
        )
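
    # Illustrative example: for a chain like self.linear.weight accessed through
    # two prim::GetAttr nodes rooted at a prim::CreateObject value (which maps
    # to ""), attribute_map records "linear" for the inner output and
    # "linear.weight" for the outer one.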

    def convert_call_function_op(self, node: torch._C.Node):
        target = get_op_overload(node)

        if target is torch.ops.aten.size.int:
            target = torch.ops.aten.sym_size.int

        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_TupleConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def convert_prim_ListConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def _convert_prim_iterator(self, node: torch._C.Node):
        output_list = []
        for inp in node.inputs():
            output_list.append(self.get_fx_value(inp))

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_list

    def convert_prim_DictConstruct(self, node: torch._C.Node):
        output_dict = {}
        k, v = None, None

        for i, inp in enumerate(node.inputs()):
            # We assume keys and values are stored in pairs in the DictConstruct:
            # even-indexed inputs are keys and odd-indexed inputs are values.
            if i % 2 == 0:
                k = self.get_fx_value(inp)
            else:
                v = self.get_fx_value(inp)
                assert (
                    k is not None and v is not None
                ), "DictConstruct has an empty key value pair."
                output_dict[k] = v
                k, v = None, None

        assert (
            k is None and v is None
        ), "DictConstruct has an odd number of elements (violating our assumption)."

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_dict

    def convert_prim_ListUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def convert_prim_TupleUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def _convert_prim_unpack_iterator(self, node: torch._C.Node):
        # Single input and multiple outputs for unpacking.
        for i, outp in enumerate(node.outputs()):
            outp_name = outp.debugName()
            inp = self.get_fx_value(node.input())
            fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
            self.name_to_node[outp_name] = fx_node

    def convert_aten_Int(self, node: torch._C.Node):
        # Converts aten::Int into aten._to_copy + aten._local_scalar_dense.
        target = torch.ops.aten._to_copy.default
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        to_copy_node = self.fx_graph.call_function(
            target, args, {"dtype": torch.int32}
        )

        fx_node = self.fx_graph.call_function(
            torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
        )

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_NumToTensor(self, node: torch._C.Node):
        # Converts prim::NumToTensor into aten.scalar_tensor.
        target = torch.ops.aten.scalar_tensor
        args = tuple(self.get_fx_value(input) for input in node.inputs())

        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_CreateObject(self, node: torch._C.Node):
        output_name = node.output().debugName()
        self.attribute_map[output_name] = ""

    def convert_aten__convolution(self, node: torch._C.Node):
        # Converts aten::_convolution into aten.convolution, since
        # aten::_convolution doesn't have a meta function.
        target = torch.ops.aten.convolution.default
        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_aten_div(self, node: torch._C.Node):
        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        # Convert aten::div.Tensor_mode(x, tensor_constant)
        # to aten.div.Scalar_mode(x, tensor_constant.item()).
        if schema.overload_name == "Tensor_mode":
            arg1_name = args[1].name
            if arg1_name in self.tensor_constants:
                tensor_constant = self.tensor_constants[arg1_name]
                if tensor_constant.numel() == 1:
                    updated_args = list(args)
                    updated_args[1] = self.tensor_constants[arg1_name].item()

                    fx_node = self.fx_graph.call_function(
                        torch.ops.aten.div.Scalar_mode,
                        tuple(updated_args),
                        kwargs,
                    )

                    # TODO: convert sourceRange() into stack_trace
                    # fx_node.meta["stack_trace"] = node.sourceRange()

                    output_name = node.output().debugName()
                    self.name_to_node[output_name] = fx_node

                    return

        self.convert_call_function_op(node)

    def convert_aten___getitem__(self, node: torch._C.Node):
        input_container, index = tuple(
            self.get_fx_value(input) for input in node.inputs()
        )
        fx_node = self.fx_graph.call_function(
            operator.getitem, (input_container, index)
        )
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_If(self, node: torch._C.Node):
        inputs = list(node.inputs())
        assert len(inputs) == 1
        predicate = self.get_fx_value(inputs[0])

        # Get the union of the inputs used by both blocks.
        arguments = set()
        for block in node.blocks():
            block_args = set()

            # TODO: block.inputs(), not sure what they're used for

            for block_node in block.nodes():
                for block_node_in in block_node.inputs():
                    if block_node_in.debugName() in self.name_to_node:
                        block_args.add(block_node_in.debugName())

            arguments.update(block_args)

        arguments = list(arguments)

        # Convert blocks to subgraphs.
        subgraph_nodes = []
        for block in node.blocks():
            subgraph_converter = TS2FXGraphConverter(block, set(), set())
            subgraph_converter.constant_map = self.constant_map

            for block_arg in arguments:
                normalized_block_arg_name = normalize_name(block_arg)
                placeholder_node = subgraph_converter.fx_graph.placeholder(
                    normalized_block_arg_name
                )
                subgraph_converter.name_to_node[block_arg] = placeholder_node

            subgraph = subgraph_converter.convert()
            subgraph_name = self.add_subgraph(subgraph)
            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))

        assert len(subgraph_nodes) == 2

        fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments]
        args = (
            predicate,
            subgraph_nodes[0],
            subgraph_nodes[1],
            tuple(fx_block_args),
        )

        cond_node = self.fx_graph.call_function(torch.cond, args, {})

        output_name = node.output().debugName()
        self.name_to_node[output_name] = cond_node
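
    # Illustrative sketch (toy TorchScript program assumed, not taken from this
    # file): a scripted branch like
    #     if pred:
    #         y = x + 1
    #     else:
    #         y = x - 1
    # becomes torch.cond(pred, subgraph_0, subgraph_1, (x,)), where each
    # subgraph_i is an FX GraphModule produced by a nested TS2FXGraphConverter.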

    def convert_aten_Bool(self, node: torch._C.Node):
        self._convert_as_noop(node)

    def _convert_as_noop(self, node: torch._C.Node):
        # Convert the node as a no-op by mapping its output to its first argument.
        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = args[0]

    def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
        target = torch.ops.profiler._record_function_enter_new
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_profiler__record_function_exit(self, node: torch._C.Node):
        # _record_function_exit has a side effect, so we keep it in the fx.Graph.
        # Currently, _record_function_enter_new and _record_function_exit are
        # discarded during `retrace_as_exported_program`.
        target = torch.ops.profiler._record_function_exit
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        self.fx_graph.call_function(target, args)

    def _convert_standard_operators(self, node: torch._C.Node):
        target = kind_to_standard_operators[node.kind()]
        args = tuple(self.get_fx_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_node(self, node: torch._C.Node):
        node_kind = node.kind()

        # Get the handler based on the namespace and operator name.
        # Fall back to the default handler if no matching converter is found.
        handler_func_name = ir_name_to_func_name(node_kind)
        handler_func = getattr(self, handler_func_name, self.convert_call_function_op)
        handler_func(node)

    def convert_graph_outputs(self):
        args = []
        for graph_output in self.ts_graph.outputs():
            output_name = graph_output.debugName()
            if output_name in self.name_to_node:
                args.append(self.name_to_node[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output_name),
                        target=output_name,
                    )
                )
            elif output_name in self.constant_map:
                args.append(self.constant_map[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=ConstantArgument(
                            name=output_name, value=self.constant_map[output_name]
                        ),
                        target=output_name,
                    )
                )
            else:
                raise ValueError(f"Output {output_name} not found")

        self.fx_graph.output(
            args[0]
        )  # Get rid of an extra list wrapped around the final output.


class TS2EPConverter:
    """TorchScript model to ExportedProgram converter."""

    def __init__(
        self,
        ts_model,
        sample_args: Tuple[Any, ...],
        sample_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.ts_model = ts_model
        self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

        self.sample_args = sample_args
        self.sample_kwargs = sample_kwargs

        self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()}
        self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()}

    def convert(self) -> ExportedProgram:
        graph_converter = TS2FXGraphConverter(
            self.ts_graph, self.param_names, self.buffer_names
        )
        gm = graph_converter.convert()

        ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants)
        return ep

    def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants):
        # TODO: adjust input orders to match GraphSignature convention
        inputs = [*self.sample_args, *self.params, *tensor_constants.values()]

        ep = torch.export._trace._export(
            gm,
            tuple(inputs),
            strict=False,
            pre_dispatch=True,
        )
        return ep
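

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the toy module and sample inputs
# below are assumptions, not part of this file):
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     ts_model = torch.jit.script(M())
#     sample_args = (torch.randn(3),)
#     ep = TS2EPConverter(ts_model, sample_args).convert()
# ---------------------------------------------------------------------------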