- # mypy: ignore-errors
- import functools
- import itertools
- import math
- import sys
- from typing import Callable, Union
- import torch
- import torch._custom_op
- import torch._logging
- from torch._ops import OpOverload
- from torch._prims_common import (
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- is_boolean_dtype,
- is_float_dtype,
- is_integer_dtype,
- )
- from torch._subclasses.fake_tensor import (
- DataDependentOutputException,
- DynamicOutputShapeException,
- FakeTensor,
- in_kernel_invocation_manager,
- run_fallback_kernel,
- UnsupportedOperatorException,
- )
- from torch.fx.operator_schemas import normalize_function
- from torch.utils._stats import count_label
- pytree = torch.utils._pytree
- __all__ = [
- "op_implementations_checks",
- "get_fast_op_impls",
- "stride_incorrect_op",
- "has_meta",
- ]
- op_implementations_dict = {}
- op_implementations_checks = []
- aten = torch._ops.ops.aten
- def ordered_set(*items):
- return dict.fromkeys(items, True)
- # Indicates whether the backend device supports non-contiguous tensors
- def is_noncontiguous_supported(device):
- if device.type == "hpu":
- return False
- return True
- _like_tensor_constructors = ordered_set(
- aten.empty_like.default,
- aten.empty_like.out,
- aten.full_like.default,
- aten.full_like.out,
- aten.ones_like.default,
- aten.ones_like.out,
- aten.rand_like.default,
- aten.rand_like.out,
- aten.randn_like.default,
- aten.randn_like.out,
- aten.randint_like.default,
- aten.randint_like.out,
- aten.randint_like.low_dtype,
- aten.randint_like.low_dtype_out,
- aten.zeros_like.default,
- aten.zeros_like.out,
- aten.new_empty.default,
- aten.new_empty.out,
- aten.new_empty_strided.default,
- aten.new_empty_strided.out,
- aten.new_full.default,
- aten.new_full.out,
- aten.new_zeros.default,
- aten.new_zeros.out,
- aten.new_ones.default,
- aten.new_ones.out,
- )
- _device_not_kwarg_ops = ordered_set(
- aten._resize_output_.default,
- aten._nested_tensor_from_tensor_list.default,
- aten._nested_tensor_from_tensor_list.out,
- aten.pin_memory.default,
- aten.is_pinned.default,
- aten.to.device,
- aten.to.prim_Device,
- aten._pin_memory.default,
- aten._pin_memory.out,
- aten._resize_output.default,
- aten._resize_output.out,
- )
- # this op is never actually used
- _non_kwarg_device_constructors = (aten._list_to_tensor,)
- def contains_tensor_types(type):
- tensor_type = torch._C.TensorType.get()
- return type.isSubtypeOf(tensor_type) or any(
- contains_tensor_types(e) for e in type.containedTypes()
- )
- @functools.lru_cache(None)
- def _is_tensor_constructor(func: OpOverload):
- assert isinstance(func, OpOverload)
- schema = func._schema
- if any(contains_tensor_types(arg.type) for arg in schema.arguments):
- return False
- # TODO: no real reason to restrict multiple outputs
- return (
- len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
- )
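- # Quick illustration (editorial sketch, not part of the original module; assumes the
- # current aten schemas): randn takes no Tensor arguments and returns a single Tensor,
- # so it is treated as a constructor, while add.Tensor is not.
- def _example_is_tensor_constructor():  # hypothetical helper, for illustration only
-     assert _is_tensor_constructor(aten.randn.default)
-     assert not _is_tensor_constructor(aten.add.Tensor)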
- def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
- def impl_decorator(op_impl):
- if isinstance(run_impl_check, OpOverload):
- assert (
- run_impl_check not in op_implementations_dict
- ), f"duplicate registration: {run_impl_check}"
- op_implementations_dict[run_impl_check] = op_impl
- elif isinstance(run_impl_check, (list, tuple)):
- for op in run_impl_check:
- register_op_impl(op)(op_impl)
- else:
- assert callable(run_impl_check)
- op_implementations_checks.append((run_impl_check, op_impl))
- return op_impl
- return impl_decorator
- @register_op_impl(op_implementations_dict.__contains__)
- def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
- return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
- @register_op_impl(_is_tensor_constructor)
- @register_op_impl([*_like_tensor_constructors])
- def constructors(fake_mode, func, *args, **kwargs):
- assert func not in _non_kwarg_device_constructors
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- if "names" in kwargs:
- raise UnsupportedOperatorException(
- "torch.compile doesn't support named tensors"
- )
- if func in _like_tensor_constructors:
- default_device = new_kwargs["input"].device
- # TODO: file issue
- args = (new_kwargs.pop("input"),)
- else:
- # cpu is default device if none is specified
- default_device = torch.device("cpu")
- args = ()
- out_device = new_kwargs.pop("device", None)
- out_device = out_device if out_device is not None else default_device
- new_kwargs["device"] = torch.device("meta")
- # _like constructors have fake tensor inputs (maybe this causes the non-like
- # to fail? hmmm)
- with in_kernel_invocation_manager(fake_mode):
- r = func(*args, **new_kwargs)
- return FakeTensor(fake_mode, r, out_device)
- @register_op_impl(aten.to.prim_Device)
- @register_op_impl(aten.to.device)
- def non_kwarg_to(fake_mode, func, *args, **kwargs):
- _, new_kwargs = normalize_function(
- func, args, kwargs, normalize_to_only_use_kwargs=True
- )
- input_device = new_kwargs["device"]
- out_device = input_device if input_device else new_kwargs["input"].device
- new_kwargs["device"] = torch.device("meta")
- inp = new_kwargs.pop("input")
- with in_kernel_invocation_manager(fake_mode):
- r = func(inp, **new_kwargs)
- # TODO: I think this does the wrong thing if r is inp
- return fake_mode.fake_tensor_converter.from_meta_and_device(
- fake_mode, r, out_device
- )
- def stride_incorrect_op(op):
- if op.namespace not in ("aten", "prims"):
- return False
- if op is aten._fft_c2c.default:
- return False
- op_name = op.name()
- if "fft" in op_name:
- return True
- return False
- # These operators have meta implementations with incorrect strides
- @register_op_impl(stride_incorrect_op)
- def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
- # This is a workaround for meta implementations with incorrect strides
- def is_symbolic(x):
- if isinstance(x, FakeTensor):
- return x._has_symbolic_sizes_strides
- if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
- return True
- return False
- # For static shapes, we can fall back to eager for the real strides
- if fake_mode.allow_fallback_kernels:
- require_dynamic = any(
- is_symbolic(x) for x in itertools.chain(args, kwargs.values())
- )
- if not require_dynamic:
- flat_args, args_spec = pytree.tree_flatten((args, kwargs))
- return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)
- raise UnsupportedOperatorException(func)
- # Don't use the default device handling,
- # since the device of `the_template` is ignored
- @register_op_impl(aten.resize_as_.default)
- def resize_as_(fake_mode, func, *args, **kwargs):
- with in_kernel_invocation_manager(fake_mode):
- return func(*args, **kwargs)
- @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
- def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
- # TODO: remove me
- return constructors(fake_mode, func, *args, **kwargs)
- # index.Tensor is data-dependent only under some conditions
- @register_op_impl(
- lambda func: torch.Tag.dynamic_output_shape in func.tags
- and func
- not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
- )
- def dyn_shape(fake_mode, func, *args, **kwargs):
- raise DynamicOutputShapeException(func)
- def _unique(
- fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
- ):
- if (
- fake_mode.shape_env is None
- or not fake_mode.shape_env.allow_dynamic_output_shape_ops
- ):
- # Without symints/symfloats, cannot handle this
- raise DynamicOutputShapeException(func)
- # Do not use a memo for unique_dim
- if dim is not None or (nnz := arg.unique_memo) is None:
- # Avoid importing sympy at a module level
- from torch.fx.experimental.symbolic_shapes import (
- _constrain_range_for_size,
- has_free_symbols,
- )
- if not has_free_symbols(arg.numel()) and arg.numel() == 0:
- # If numel is zero, then the output size must be zero.
- # In this case, we must not allocate an unbacked SymInt,
- # because if we do, it will immediately get refined to
- # zero, but this will be inconsistent with size oblivious
- # tests (which will continue to claim that the unbacked
- # symint cannot equal zero). We could also unconditionally
- # allocate an unbacked SymInt and not refine its range,
- # but this seems more precise.
- nnz = 0
- else:
- nnz = fake_mode.shape_env.create_unbacked_symint()
- maxval = sys.maxsize - 1
- numel = arg.numel() if dim is None else arg.size(dim)
- if not has_free_symbols(numel):
- maxval = int(numel)
- _constrain_range_for_size(nnz, max=maxval)
- if dim is None:
- arg.unique_memo = nnz
- if dim is None:
- ret = [arg.new_empty((nnz,))]
- else:
- ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
- return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
- if return_inverse or return_if_dim_and_cpu:
- inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
- else:
- inverse = arg.new_empty(0)
- ret.append(inverse)
- if return_counts or return_if_dim_and_cpu:
- counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
- else:
- counts = arg.new_empty(0)
- ret.append(counts)
- return tuple(ret)
- @register_op_impl(aten._unique2.default)
- def unique2(
- fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
- ):
- return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
- @register_op_impl(aten.unique_dim.default)
- def unique_dim(
- fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
- ):
- return _unique(
- fake_mode,
- func,
- arg,
- # normalize dim to be non-negative
- dim if dim >= 0 else dim % max(arg.ndim, 1),
- sorted,
- return_inverse,
- return_counts,
- )
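- # Editorial sketch (not part of the original module): with a ShapeEnv that allows
- # dynamic output shape ops, unique's output length becomes an unbacked SymInt bounded
- # by the input numel, and repeat calls reuse the memo stored on the input fake tensor.
- # The constructor kwarg below mirrors the attribute checked above and is assumed available.
- def _example_unique_under_fake_mode():  # hypothetical helper, for illustration only
-     from torch._subclasses.fake_tensor import FakeTensorMode
-     from torch.fx.experimental.symbolic_shapes import ShapeEnv
-     shape_env = ShapeEnv(allow_dynamic_output_shape_ops=True)
-     with FakeTensorMode(shape_env=shape_env):
-         x = torch.randint(0, 10, (32,))
-         u = torch.unique(x)   # dispatches to unique2/_unique above
-         v = torch.unique(x)   # hits x.unique_memo, so the same symbol is returned
-         return u.shape[0], v.shape[0]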
- @register_op_impl(aten.repeat_interleave.Tensor)
- def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
- if output_size is None:
- if (
- fake_mode.shape_env is None
- or not fake_mode.shape_env.allow_dynamic_output_shape_ops
- ):
- raise DynamicOutputShapeException(func)
- output_size = fake_mode.shape_env.create_unbacked_symint()
- # Avoid importing sympy at a module level
- from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
- _constrain_range_for_size(output_size)
- # TODO: consider a memo
- return repeats.new_empty(output_size)
- @register_op_impl(torch.ops.aten._local_scalar_dense.default)
- def local_scalar_dense(fake_mode, func, arg):
- if (r := arg.item_memo) is not None:
- return r
- if fake_mode.shape_env is None or (
- not fake_mode.shape_env.allow_scalar_outputs
- and not fake_mode.allow_scalar_outputs
- ):
- # Without symints/symfloats, cannot handle this
- raise DataDependentOutputException(func)
- if is_float_dtype(arg.dtype):
- r = fake_mode.shape_env.create_unbacked_symfloat()
- elif is_integer_dtype(arg.dtype):
- r = fake_mode.shape_env.create_unbacked_symint()
- elif is_boolean_dtype(arg.dtype):
- r = fake_mode.shape_env.create_unbacked_symbool()
- else:
- raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
- arg.item_memo = r
- return r
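- # Editorial sketch (not part of the original module): the first .item() on a fake
- # tensor mints an unbacked symbol matching the dtype class (a SymInt here) and caches
- # it in item_memo; later calls return the cached symbol. The allow_scalar_outputs
- # constructor kwarg mirrors the attribute checked above and is assumed available.
- def _example_item_unbacked():  # hypothetical helper, for illustration only
-     from torch._subclasses.fake_tensor import FakeTensorMode
-     from torch.fx.experimental.symbolic_shapes import ShapeEnv
-     shape_env = ShapeEnv(allow_scalar_outputs=True)
-     with FakeTensorMode(shape_env=shape_env):
-         x = torch.zeros((), dtype=torch.int64)
-         s1 = x.item()   # unbacked SymInt
-         s2 = x.item()   # served from x.item_memo
-         return s1, s2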
- @register_op_impl(torch.ops.aten.nonzero.default)
- def nonzero(fake_mode, func, arg):
- if (
- fake_mode.shape_env is None
- or not fake_mode.shape_env.allow_dynamic_output_shape_ops
- ):
- # Without symints/symfloats, cannot handle this
- raise DynamicOutputShapeException(func)
- if (nnz := arg.nonzero_memo) is None:
- # Avoid importing sympy at a module level
- from torch.fx.experimental.symbolic_shapes import (
- _constrain_range_for_size,
- has_free_symbols,
- )
- if not has_free_symbols(arg.numel()) and arg.numel() == 0:
- # If numel is zero, then the output size must be zero.
- # In this case, we must not allocate an unbacked SymInt,
- # because if we do, it will immediately get refined to
- # zero, but this will be inconsistent with size oblivious
- # tests (which will continue to claim that the unbacked
- # symint cannot equal zero). We could also unconditionally
- # allocate an unbacked SymInt and not refine its range,
- # but this seems more precise.
- nnz = 0
- else:
- nnz = fake_mode.shape_env.create_unbacked_symint()
- maxval = sys.maxsize - 1
- if not has_free_symbols(arg.numel()):
- maxval = int(arg.numel())
- _constrain_range_for_size(nnz, max=maxval)
- arg.nonzero_memo = nnz
- return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
- @register_op_impl(torch.ops.aten.masked_select.default)
- def masked_select(fake_mode, func, self, mask):
- if (
- fake_mode.shape_env is None
- or not fake_mode.shape_env.allow_dynamic_output_shape_ops
- ):
- # Without symints/symfloats, cannot handle this
- raise DynamicOutputShapeException(func)
- nnz = fake_mode.shape_env.create_unbacked_symint()
- # see nonzero for commentary
- maxval = sys.maxsize - 1
- # Avoid importing sympy at a module level
- from torch.fx.experimental.symbolic_shapes import (
- _constrain_range_for_size,
- has_free_symbols,
- )
- if not has_free_symbols(self.numel()):
- if self.numel() > 2:
- maxval = int(self.numel())
- _constrain_range_for_size(nnz, max=maxval)
- return self.new_empty((nnz,))
- # NB: this must be ordered after local_scalar_dense
- @register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
- def data_dep(fake_mode, func, *args, **kwargs):
- raise DataDependentOutputException(func)
- # Bool Indices get Expanded as Masks
- # See: IndexingUtils.h:expandTensors
- def check_no_bool_index_tensors(func, self, indices):
- for index in indices:
- if index is not None and index.dtype in (torch.bool, torch.uint8):
- raise DynamicOutputShapeException(func)
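- # Why bool indices force a dynamic output shape (editorial, real-tensor illustration,
- # not part of the original module): the result size equals mask.sum(), which is only
- # known once the data is known.
- def _example_bool_index_is_data_dependent():  # hypothetical helper, for illustration only
-     t = torch.arange(4.0)
-     mask = torch.tensor([True, False, True, True])
-     return t[mask].shape   # torch.Size([3]), i.e. mask.sum() elements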
- def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- out_device = new_kwargs["input"].device
- with in_kernel_invocation_manager(fake_mode):
- out = func(*args, **kwargs)
- if not is_noncontiguous_supported(out_device):
- out = out.new_empty(out.shape)
- if out is new_kwargs["input"]:
- return out # copy_
- return FakeTensor(fake_mode, out, out_device)
- _is_builtin_namespaces = ordered_set("aten", "prims", "prim")
- def is_builtin(op):
- return op.namespace in _is_builtin_namespaces
- def has_meta(func):
- return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
- @register_op_impl(
- lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
- )
- def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
- tensor_lists = []
- for arg in itertools.chain(args, kwargs.values()):
- if (
- isinstance(arg, (list, tuple))
- and len(arg)
- and isinstance(arg[0], torch.Tensor)
- ):
- tensor_lists.append(arg)
- try:
- with in_kernel_invocation_manager(fake_mode):
- out_meta = func(*args, **kwargs)
- except NotImplementedError:
- return NotImplemented
- if not out_meta:
- return out_meta
- assert tensor_lists
- out_fake = []
- for i, meta_t in enumerate(out_meta):
- device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
- out_fake.append(
- fake_mode.fake_tensor_converter.from_meta_and_device(
- fake_mode, meta_t, device
- )
- )
- return out_fake
- # Don't use the default device handling,
- # since the op can take non-zero-sized cpu
- # index tensors with a cuda self
- @register_op_impl(aten.index.Tensor)
- def index_tensor(fake_mode, func, *args, **kwargs):
- from torch._meta_registrations import meta_index_Tensor
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- out_device = new_kwargs["input"].device
- # ensure nonzero call goes to fake tensor
- with fake_mode:
- out = meta_index_Tensor(*args, **kwargs)
- return out.to(out_device)
- # Can take mixed meta/non-meta arguments; the meta registration
- # will roughly do the right thing even when given real devices
- @register_op_impl(aten._embedding_bag.default)
- def embedding_bag(fake_mode, func, *args, **kwargs):
- from torch._meta_registrations import meta_embedding_bag
- with fake_mode:
- return meta_embedding_bag(*args, **kwargs)
- # takes in multiple devices; don't use the default device handling
- @register_op_impl(aten._unsafe_index_put.default)
- @register_op_impl(aten.copy.default)
- @register_op_impl(aten.copy_.default)
- @register_op_impl(aten.slice_scatter.default)
- def multi_device_op_default(fake_mode, func, *args, **kwargs):
- return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
- # same as multi_device_op_default, but returns the input
- @register_op_impl(aten.copy.out)
- @register_op_impl(aten.slice_scatter.out)
- def multi_device_op_out(fake_mode, func, *args, **kwargs):
- with in_kernel_invocation_manager(fake_mode):
- out = func(*args, **kwargs)
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- return new_kwargs["input"]
- @register_op_impl(aten.index_put.default)
- @register_op_impl(aten.index_put_.default)
- def index_put_impl(fake_mode, func, *args, **kwargs):
- _, new_kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- values = new_kwargs["values"]
- self_device = new_kwargs["input"].fake_device
- torch._check(
- self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
- lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
- )
- out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
- if func is aten.index_put_.default:
- return new_kwargs["input"]
- else:
- return out
- @register_op_impl(aten._nested_tensor_from_tensor_list.default)
- @register_op_impl(aten._nested_tensor_from_tensor_list.out)
- @register_op_impl(aten._nested_view_from_buffer.default)
- @register_op_impl(aten._nested_view_from_buffer_copy.default)
- def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
- raise UnsupportedOperatorException(
- "torch.compile does not support strided NestedTensor"
- )
- @register_op_impl(
- [
- x
- for x in _device_not_kwarg_ops
- if x
- not in (
- # these are already registered elsewhere
- aten.to.device,
- aten.to.prim_Device,
- aten._nested_tensor_from_tensor_list.default,
- aten._nested_tensor_from_tensor_list.out,
- )
- ]
- )
- def nyi(fake_mode, func, *args, **kwargs):
- assert func not in _device_not_kwarg_ops, f"NYI: {func}"
- @register_op_impl([aten.convolution.default, aten.convolution_backward.default])
- def conv(fake_mode, func, *args, **kwargs):
- _, kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- device = kwargs["input"].fake_device
- # need to re-enable mode so the tensors report fake device
- with fake_mode:
- # if input unsqueezing is done in Convolution.cpp we get a segfault
- k = kwargs["weight"].ndim
- batch = kwargs["input"].shape[0]
- # Avoid importing sympy at a module level
- from torch.fx.experimental.symbolic_shapes import has_hint
- if not has_hint(batch):
- # TODO: We can make this a little more faithful with best effort
- # channels last detection (but only if it's statically obvious!)
- mem_fmt = None
- elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
- mem_fmt = None
- else:
- if func is aten.convolution.default:
- conv_backend = torch._C._select_conv_backend(**kwargs)
- else:
- conv_backend = torch._C._select_conv_backend(
- kwargs["input"],
- kwargs["weight"],
- bias=None,
- stride=kwargs["stride"],
- padding=kwargs["padding"],
- dilation=kwargs["dilation"],
- transposed=kwargs["transposed"],
- output_padding=kwargs["output_padding"],
- groups=kwargs["groups"],
- bias_sizes=kwargs["bias_sizes"],
- )
- mem_fmt = torch._C._conv_determine_backend_memory_format(
- kwargs["input"], kwargs["weight"], conv_backend
- )
- def convert(t, mem_fmt):
- if t is None:
- return t
- if mem_fmt is not None:
- t = t.to(memory_format=mem_fmt)
- return FakeTensor(fake_mode, t, device)
- with in_kernel_invocation_manager(fake_mode):
- out = func(**kwargs)
- if func is aten.convolution.default:
- return convert(out, mem_fmt)
- else:
- return (
- convert(out[0], mem_fmt),
- convert(out[1], mem_fmt),
- convert(out[2], None),
- )
- @register_op_impl(aten._scaled_dot_product_flash_attention.default)
- def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
- _, kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- query = kwargs["query"]
- key = kwargs["key"]
- return_debug_mask = kwargs["return_debug_mask"]
- # unused: value, dropout_p, is_causal, scale
- def convert_tensor(t, device):
- return FakeTensor(fake_mode, t, device)
- batch_size = query.size(0)
- num_heads = query.size(1)
- max_seqlen_batch_q = query.size(2)
- head_dim = query.size(3)
- max_seqlen_batch_k = key.size(2)
- query_t = query.transpose(1, 2)
- # empty_like already returns a fake tensor so we don't need to convert it
- attention = torch.empty_like(query_t).transpose(1, 2)
- logsumexp = convert_tensor(
- torch.empty(
- (batch_size, num_heads, max_seqlen_batch_q),
- dtype=torch.float,
- device="meta",
- ),
- device=query.device,
- )
- if return_debug_mask:
- blocksize_c = 128 if head_dim > 64 else 256
- max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
- if max_seqlen_batch_k <= 128:
- max_seqlen_k = 128
- elif max_seqlen_batch_k <= 256:
- max_seqlen_k = 256
- debug_mask = convert_tensor(
- torch.empty(
- (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
- dtype=query.dtype,
- device="meta",
- ),
- device=query.device,
- )
- else:
- debug_mask = convert_tensor(
- torch.empty(0, dtype=query.dtype, device="meta"),
- query.device,
- )
- # Note [Seed and Offset]: device for seed and offset below depends on whether we are
- # capturing or not, but at the time of tracing we don't know if we
- # are going to use cudagraphs or not, so we return meta tensors here
- # it's possible we'll need to have some special handling in inductor for sdpa
- return (
- attention,
- logsumexp,
- None,
- None,
- max_seqlen_batch_q,
- max_seqlen_batch_k,
- convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
- convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
- debug_mask,
- )
- @register_op_impl(aten._scaled_dot_product_efficient_attention.default)
- def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
- _, kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- query = kwargs["query"]
- key = kwargs["key"]
- value = kwargs["value"]
- compute_log_sumexp = kwargs["compute_log_sumexp"]
- # unused: attn_bias, dropout_p, is_causal, scale
- def convert_tensor(t, device):
- return FakeTensor(fake_mode, t, device)
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
- B = query.size(0)
- M = query.size(1)
- N = key.size(1)
- num_heads = query.size(-2)
- K = query.size(-1)
- Kv = value.size(-1)
- res = convert_tensor(
- torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
- query.device,
- )
- logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
- logsum_exp = convert_tensor(
- torch.empty(
- (B, num_heads, logsumexp_dim),
- dtype=torch.float,
- device="meta",
- ),
- query.device,
- )
- res = res.transpose(1, 2)
- # See Note [Seed and Offset]:
- seed = convert_tensor(
- torch.empty((), dtype=torch.long, device="meta"), query.device
- )
- offset = convert_tensor(
- torch.empty((), dtype=torch.long, device="meta"), query.device
- )
- return res, logsum_exp, seed, offset
- @register_op_impl(aten._flash_attention_forward.default)
- def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
- _, kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- query = kwargs["query"]
- key = kwargs["key"]
- cum_seq_q = kwargs["cum_seq_q"]
- cum_seq_k = kwargs["cum_seq_k"]
- max_q = kwargs["max_q"]
- max_k = kwargs["max_k"]
- return_debug_mask = kwargs["return_debug_mask"]
- # unused: value, dropout_p, is_causal, scale
- # unused: seqused_k, alibi_slopes, window_size_left, window_size_right
- def convert_tensor(t, device):
- return FakeTensor(fake_mode, t, device)
- # NB: there are two underlying paths:
- # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
- # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
- # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
- batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
- max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
- max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
- num_heads = query.size(-2)
- head_dim = query.size(-1)
- # Cuda Path
- # note: empty_like already returns a fake tensor, we don't need to wrap it
- attention = torch.empty_like(query)
- logsumexp = convert_tensor(
- torch.empty(
- (batch_size, num_heads, max_seqlen_batch_q),
- dtype=torch.float,
- device="meta",
- ),
- device=query.device,
- )
- if return_debug_mask:
- blocksize_c = 128 if head_dim > 64 else 256
- max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
- if max_seqlen_batch_k <= 128:
- max_seqlen_k = 128
- elif max_seqlen_batch_k <= 256:
- max_seqlen_k = 256
- debug_mask = convert_tensor(
- torch.empty(
- (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
- dtype=query.dtype,
- device="meta",
- ),
- query.device,
- )
- else:
- debug_mask = convert_tensor(
- torch.empty(0, dtype=query.dtype, device="meta"),
- query.device,
- )
- # See Note [Seed and Offset]:
- return (
- attention,
- logsumexp,
- convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
- convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
- debug_mask,
- )
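- # Worked example of the varseqlen bookkeeping above (editorial, illustrative values
- # only; not part of the original module).
- def _example_varseqlen_sizes():  # hypothetical helper, for illustration only
-     cum_seq_q = torch.tensor([0, 5, 12, 20])   # three sequences of lengths 5, 7, 8
-     batch_size = cum_seq_q.numel() - 1         # -> 3, as computed above
-     total_tokens = int(cum_seq_q[-1])          # -> 20; query would be (20, num_heads, head_dim)
-     return batch_size, total_tokens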
- @register_op_impl(aten._efficient_attention_forward.default)
- def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
- _, kwargs = normalize_function(
- func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
- )
- query = kwargs["query"]
- key = kwargs["key"]
- value = kwargs["value"]
- cu_seqlens_q = kwargs["cu_seqlens_q"]
- max_seqlen_q = kwargs["max_seqlen_q"]
- max_seqlen_k = kwargs["max_seqlen_k"]
- compute_log_sumexp = kwargs["compute_log_sumexp"]
- # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k
- def convert_tensor(t, device):
- return FakeTensor(fake_mode, t, device)
- B = query.size(0)
- M = query.size(1)
- N = key.size(1)
- num_heads = query.size(-2)
- K = query.size(-1)
- Kv = value.size(-1)
- res = convert_tensor(
- torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
- query.device,
- )
- logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
- actual_max_seqlen_q = M
- if cu_seqlens_q is not None:
- assert max_seqlen_q is not None
- actual_max_seqlen_q = max_seqlen_q
- actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
- logsumexp_dim = (
- math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
- )
- logsum_exp = convert_tensor(
- torch.empty(
- (logsumexp_batch_dim, num_heads, logsumexp_dim),
- dtype=torch.float,
- device="meta",
- ),
- query.device,
- )
- # See Note [Seed and Offset]:
- seed = convert_tensor(
- torch.empty((), dtype=torch.long, device="meta"), query.device
- )
- offset = convert_tensor(
- torch.empty((), dtype=torch.long, device="meta"), query.device
- )
- return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
- @register_op_impl(torch.ops.aten._pack_padded_sequence.default)
- def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
- if (
- fake_mode.shape_env is None
- or not fake_mode.shape_env.allow_dynamic_output_shape_ops
- ):
- # Without symints/symfloats, cannot handle this
- raise DynamicOutputShapeException(func)
- new_batch_size = fake_mode.shape_env.create_unbacked_symint()
- from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
- _constrain_range_for_size(new_batch_size)
- if not batch_first:
- # Inputs should have shape (batch_size, seq_len, *)
- inputs = inputs.transpose(0, 1)
- res_size = inputs.shape[1:]
- packed_data = inputs.new_empty(res_size)
- batch_size = inputs.new_empty((new_batch_size,))
- return (packed_data, batch_size)
- FAST_OP_IMPLEMENTATIONS = {}
- # Unlike register_op_impl, these don't do the slow iteration for
- # run_impl_check, and these run BEFORE decompositions
- def register_fast_op_impl(func: OpOverload):
- def impl_decorator(op_impl):
- FAST_OP_IMPLEMENTATIONS[func] = op_impl
- return op_impl
- return impl_decorator
- # infer_size_impl in ExpandUtils
- def infer_size(a, b):
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- dimsA = len(a)
- dimsB = len(b)
- ndim = max(dimsA, dimsB)
- expandedSizes = [0] * ndim
- for i in range(ndim - 1, -1, -1):
- offset = ndim - 1 - i
- dimA = dimsA - 1 - offset
- dimB = dimsB - 1 - offset
- sizeA = a[dimA] if dimA >= 0 else 1
- sizeB = b[dimB] if dimB >= 0 else 1
- # NB: It is very important to test for broadcasting, before testing
- # sizeA == sizeB. This is because the broadcasting tests are likely
- # to be statically known (in particular, if sizeA/sizeB is unbacked
- # but size-like, we will unsoundly assume they never equal 1), but
- # the sizeA == sizeB test may not be statically known. However, once
- # we have established that no broadcasting is happening, the
- # sizeA == sizeB is now expect_true and we can defer it as a runtime
- # assert (this works because Python will return the terminal
- # expression of an or statement as-is, without bool()'ing it; if this
- # were not the case, we'd need to write this using torch.sym_or() or
- # something like that).
- torch._check(
- guard_size_oblivious(sizeA == 1)
- or guard_size_oblivious(sizeB == 1)
- or sizeA == sizeB,
- lambda: f"The size of tensor a ({sizeA}) "
- f"must match the size of tensor b ({sizeB}) "
- f"at non-singleton dimension {i})",
- )
- expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
- return tuple(expandedSizes)
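- # Sanity sketch for the broadcasting logic above (editorial, not part of the original
- # module): with plain ints, guard_size_oblivious reduces to the bools themselves, so
- # infer_size behaves like ordinary broadcast shape inference.
- def _example_infer_size():  # hypothetical helper, for illustration only
-     assert infer_size((5, 1, 3), (4, 3)) == (5, 4, 3)   # 1 expands; missing dims prepend
-     # infer_size((2, 3), (3, 3)) would fail torch._check: 2 vs 3 at a non-singleton dim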
- def make_fast_binary_impl(slow_ref):
- def fast_binary_impl(mode, *args, **kwargs):
- def slow(msg):
- count_label(f"slow {msg}")
- with mode:
- return slow_ref(*args, **kwargs)
- count_label("attempt fast")
- # Fast path (based off of TensorIterator fast path).
- # Unfortunately, there is no way to easily deduplicate
- # this with either the TensorIterator C++ implementation
- # (which we don't want to SymIntify, and also the algorithm
- # here is slightly different from TensorIterator to allow
- # for broadcasting), nor the PrimTorch implementation
- # (which does not actually implement a fast path.)
- operands = args
- # compute_shape
- has_scalars = False
- has_tensors = False
- final_shape = None
- for op in operands:
- shape = op.shape if isinstance(op, torch.Tensor) else ()
- if len(shape) == 0:
- has_scalars = True
- else:
- has_tensors = True
- if final_shape is None:
- final_shape = shape
- # TODO: Minor optimization: track if the shapes
- # were equal so you can skip the equality check
- # below if unnecessary
- final_shape = infer_size(final_shape, shape)
- assert final_shape is not None
- # Do some extra safety checks to see if the output
- # stride is obvious
- for op in operands:
- if (
- isinstance(op, torch.Tensor)
- and len(op.shape) == len(final_shape)
- and op.shape == final_shape
- ):
- break
- else:
- return slow("both tensors nontrivially broadcast")
- # compute_types
- cpu = torch.device("cpu")
- common_device = cpu
- common_dtype = None
- output_dtype = None
- has_different_input_dtypes = False
- for op in operands:
- if not isinstance(op, torch.Tensor):
- # Use elementwise_dtypes for the tricky case
- has_different_input_dtypes = True
- continue
- if common_device == cpu and not op.device.type == "cpu":
- common_device = op.device
- # Slightly simplified here as target_dtype cannot vary
- if common_dtype is None:
- common_dtype = op.dtype
- elif common_dtype != op.dtype:
- has_different_input_dtypes = True
- if has_different_input_dtypes:
- # compute promotion
- # TODO: we don't need the compute type
- _, common_dtype = elementwise_dtypes(
- *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- # check that all tensors are on the same device;
- # cpu scalars are assumed to be allowed
- current_cpu_scalars_on_non_cpu = 0
- max_cpu_scalars_on_non_cpu = 1 # hard coded atm
- for op in operands:
- if not isinstance(op, torch.Tensor):
- continue
- if common_device != cpu and op.dim() == 0 and op.device == cpu:
- if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
- return slow("error")
- current_cpu_scalars_on_non_cpu += 1
- elif op.device != common_device:
- return slow("error")
- # compute_fast_setup_type
- is_contiguous = True
- is_channels_last = True
- # TODO: is_non_overlapping_and_dense (not bound from Python);
- # no inplace, no out, everything defined
- if is_noncontiguous_supported(common_device):
- for op in operands:
- if not isinstance(op, torch.Tensor):
- continue
- is_contiguous = is_contiguous and op.is_contiguous(
- memory_format=torch.contiguous_format
- )
- is_channels_last = is_channels_last and op.is_contiguous(
- memory_format=torch.channels_last
- )
- if is_contiguous:
- # do contiguous
- count_label("fast is_contiguous")
- return FakeTensor(
- mode,
- torch.empty(
- final_shape,
- dtype=common_dtype,
- device="meta",
- memory_format=torch.contiguous_format,
- ),
- device=common_device,
- )
- if is_channels_last:
- count_label("fast channels_last")
- # do channels last
- return FakeTensor(
- mode,
- torch.empty(
- final_shape,
- dtype=common_dtype,
- device="meta",
- memory_format=torch.channels_last,
- ),
- device=common_device,
- )
- return slow("no contiguity match")
- return fast_binary_impl
- @functools.lru_cache(None)
- def get_fast_op_impls():
- import torch._refs
- register_fast_op_impl(torch.ops.aten.add.Tensor)(
- make_fast_binary_impl(torch._refs.add)
- )
- register_fast_op_impl(torch.ops.aten.sub.Tensor)(
- make_fast_binary_impl(torch._refs.sub)
- )
- register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
- register_fast_op_impl(torch.ops.aten.div.Tensor)(
- make_fast_binary_impl(torch._refs.div)
- )
- return FAST_OP_IMPLEMENTATIONS
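- # Observable behavior of the fast binary path (editorial sketch, not part of the
- # original module): broadcasting and dtype promotion under FakeTensorMode match eager
- # semantics, whether the fast path or the slow reference fallback is taken.
- def _example_fast_binary_shapes():  # hypothetical helper, for illustration only
-     from torch._subclasses.fake_tensor import FakeTensorMode
-     with FakeTensorMode():
-         a = torch.empty(4, 1, dtype=torch.float32)
-         b = torch.empty(1, 8, dtype=torch.int64)
-         c = a + b
-         return c.shape, c.dtype   # (torch.Size([4, 8]), torch.float32)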