- # mypy: allow-untyped-defs
- import math
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- number = Union[int, float]
- # flake8: noqa
- ###
- # There are generated files that depend on this file
- # To re-generate, please run from the root of the repo:
- # python torchgen/shape_functions/gen_jit_shape_functions.py
- # How to test:
- # After regenerating files, compile PyTorch.
- # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
- # If you have enabled opinfo testing for the op, also run:
- # python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
- # to reproduce errors from opinfo tests.
- # Example PR: https://github.com/pytorch/pytorch/pull/80860/files
- ####
- import torch
- def broadcast(a: List[int], b: List[int]):
- dimsA = len(a)
- dimsB = len(b)
- ndim = max(dimsA, dimsB)
- expandedSizes: List[int] = []
- for i in range(ndim):
- offset = ndim - 1 - i
- dimA = dimsA - 1 - offset
- dimB = dimsB - 1 - offset
- sizeA = a[dimA] if (dimA >= 0) else 1
- sizeB = b[dimB] if (dimB >= 0) else 1
- if sizeA != sizeB and sizeA != 1 and sizeB != 1:
- # TODO: only assertion error is bound in C++ compilation right now
- raise AssertionError(
- f"The size of tensor a {sizeA} must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
- )
- expandedSizes.append(sizeB if sizeA == 1 else sizeA)
- return expandedSizes
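- # For illustration (a sketch of the rule above, standard PyTorch broadcasting semantics):
- #   broadcast([8, 1, 6, 1], [7, 1, 5])  ->  [8, 7, 6, 5]
- #   broadcast([3, 4], [2])              ->  raises AssertionError (4 vs. 2, neither is 1)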
- def broadcast_three(a: List[int], b: List[int], c: List[int]):
- return broadcast(broadcast(a, b), c)
- def broadcast_one_three(a: List[int], b: Any, c: List[int]):
- return broadcast(a, c)
- def adaptive_avg_pool2d(self: List[int], out: List[int]):
- assert len(out) == 2
- assert len(self) == 3 or len(self) == 4
- for i in range(1, len(self)):
- assert self[i] != 0
- shape: List[int] = []
- for i in range(0, len(self) - 2):
- shape.append(self[i])
- for elem in out:
- shape.append(elem)
- return shape
- def _copy(self: List[int]):
- out: List[int] = []
- for elem in self:
- out.append(elem)
- return out
- def unary(self: List[int]):
- return _copy(self)
- def broadcast_inplace(a: List[int], b: List[int]):
- dimsA = len(a)
- dimsB = len(b)
- if dimsB > dimsA:
- raise AssertionError(
- f"The dims of tensor b ({dimsB}) must be less than or equal tothe dims of tensor a ({dimsA}) "
- )
- for dimA in range(dimsA):
- dimB = dimsB - dimsA + dimA
- sizeA = a[dimA]
- sizeB = b[dimB] if (dimB >= 0) else 1
- if sizeA != sizeB and sizeB != 1:
- # TODO: only assertion error is bound in C++ compilation right now
- raise AssertionError(
- "The size of tensor a {} must match the size of tensor b ("
- "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
- )
- return _copy(a)
- def expand(self: List[int], sizes: List[int]):
- assert len(sizes) >= len(self)
- ndim = len(sizes)
- tensor_dim = len(self)
- if ndim == 0:
- return _copy(sizes)
- out: List[int] = []
- for i in range(ndim):
- offset = ndim - 1 - i
- dim = tensor_dim - 1 - offset
- size = self[dim] if dim >= 0 else 1
- targetSize = sizes[i]
- if targetSize == -1:
- assert dim >= 0
- targetSize = size
- if size != targetSize:
- assert size == 1
- size = targetSize
- out.append(size)
- return out
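- # For illustration (a sketch; -1 keeps the existing size at that dimension):
- #   expand([3, 1], [2, 3, 4])  ->  [2, 3, 4]
- #   expand([3, 1], [-1, 4])    ->  [3, 4]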
- def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
- return expand(self, sizes)
- def infer_size_impl(shape: List[int], numel: int) -> List[int]:
- newsize = 1
- infer_dim: Optional[int] = None
- for dim in range(len(shape)):
- if shape[dim] == -1:
- if infer_dim is not None:
- raise AssertionError("only one dimension can be inferred")
- infer_dim = dim
- elif shape[dim] >= 0:
- newsize *= shape[dim]
- else:
- raise AssertionError("invalid shape dimensions")
- if not (
- numel == newsize
- or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
- ):
- raise AssertionError("invalid shape")
- out = _copy(shape)
- if infer_dim is not None:
- out[infer_dim] = numel // newsize
- return out
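- # For illustration (a sketch; -1 marks the dimension to infer from numel):
- #   infer_size_impl([2, -1], 6)  ->  [2, 3]
- #   infer_size_impl([5, 2], 10)  ->  [5, 2]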
- def numel(sizes: List[int]):
- numel = 1
- for elem in sizes:
- numel *= elem
- return numel
- def view(self: List[int], sizes: List[int]):
- return infer_size_impl(sizes, numel(self))
- def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
- return view(self, sizes)
- def sum_mean_dim(
- self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
- ):
- out: List[int] = []
- if opt_dims is None or len(opt_dims) == 0:
- dims: List[int] = list(range(len(self)))
- else:
- dims = opt_dims
- for idx in range(len(self)):
- is_mean_dim: bool = False
- for reduce_dim in dims:
- if idx == maybe_wrap_dim(reduce_dim, len(self)):
- is_mean_dim = True
- if is_mean_dim:
- if keep_dim:
- out.append(1)
- else:
- out.append(self[idx])
- return out
- def max_dim(self: List[int], dim: int, keep_dim: bool):
- out = sum_mean_dim(self, [dim], keep_dim, None)
- return out, out
- # Note: Python integer division already rounds toward negative infinity, so no special arithmetic is needed
- def div_rtn(x: int, y: int):
- return x // y
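- # For illustration (Python floor division already rounds toward negative infinity):
- #   div_rtn(5, 2)   ->  2
- #   div_rtn(-1, 2)  ->  -1   (C-style truncation would give 0)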
- def pooling_output_shape_pad_lr(
- inputSize: int,
- kernelSize: int,
- pad_l: int,
- pad_r: int,
- stride: int,
- dilation: int,
- ceil_mode: bool,
- ):
- outputSize = (
- div_rtn(
- inputSize
- + pad_l
- + pad_r
- - dilation * (kernelSize - 1)
- - 1
- + (stride - 1 if ceil_mode else 0),
- stride,
- )
- + 1
- )
- if ceil_mode:
- if (outputSize - 1) * stride >= inputSize + pad_l:
- outputSize = outputSize - 1
- return outputSize
- def pooling_output_shape(
- inputSize: int,
- kernelSize: int,
- pad_l: int,
- stride: int,
- dilation: int,
- ceil_mode: bool,
- ):
- assert stride != 0, "stride should not be zero"
- return pooling_output_shape_pad_lr(
- inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
- )
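- # For illustration (a sketch; ceil_mode may add one extra, possibly clipped, output position):
- #   pooling_output_shape(5, 2, 0, 2, 1, False)  ->  2   # floor((5 - 2) / 2) + 1
- #   pooling_output_shape(5, 2, 0, 2, 1, True)   ->  3   # ceil((5 - 2) / 2) + 1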
- def pool2d_shape_check(
- input: List[int],
- kH: int,
- kW: int,
- dH: int,
- dW: int,
- padH: int,
- padW: int,
- dilationH: int,
- dilationW: int,
- nInputPlane: int,
- inputHeight: int,
- inputWidth: int,
- outputHeight: int,
- outputWidth: int,
- ):
- ndim = len(input)
- nOutputPlane = nInputPlane
- assert kW > 0 and kH > 0
- assert dW > 0 and dH > 0
- assert dilationH > 0 and dilationW > 0
- valid_dims = input[1] != 0 and input[2] != 0
- assert (
- ndim == 3
- and input[0] != 0
- and valid_dims
- or (ndim == 4 and valid_dims and input[3] != 0)
- )
- assert kW // 2 >= padW and kH // 2 >= padH
- assert outputWidth >= 1 and outputHeight >= 1
- def max_pool2d(
- input: List[int],
- kernel_size: List[int],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- ceil_mode: bool,
- ):
- assert (
- len(kernel_size) == 1 or len(kernel_size) == 2
- ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
- kH = kernel_size[0]
- kW = kH if len(kernel_size) == 1 else kernel_size[1]
- assert (
- len(stride) == 0 or len(stride) == 1 or len(stride) == 2
- ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
- dH = kH if len(stride) == 0 else stride[0]
- if len(stride) == 0:
- dW = kW
- elif len(stride) == 1:
- dW = dH
- else:
- dW = stride[1]
- assert (
- len(padding) == 1 or len(padding) == 2
- ), "max_pool2d: padding must either be a single int, or a tuple of two ints"
- padH = padding[0]
- padW = padH if len(padding) == 1 else padding[1]
- assert (
- len(dilation) == 1 or len(dilation) == 2
- ), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
- dilationH = dilation[0]
- dilationW = dilationH if len(dilation) == 1 else dilation[1]
- assert len(input) == 3 or len(input) == 4
- nbatch = input[-4] if len(input) == 4 else 1
- nInputPlane = input[-3]
- inputHeight = input[-2]
- inputWidth = input[-1]
- outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
- outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
- pool2d_shape_check(
- input,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- dilationH,
- dilationW,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- )
- if len(input) == 3:
- return [nInputPlane, outputHeight, outputWidth]
- else:
- return [nbatch, nInputPlane, outputHeight, outputWidth]
- def max_pool2d_with_indices(
- input: List[int],
- kernel_size: List[int],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- ceil_mode: bool,
- ):
- out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
- return (out, out)
- def upsample_nearest2d(
- input: List[int],
- output_size: Optional[List[int]],
- scale_factors: Optional[List[float]],
- ):
- out: List[int] = []
- out.append(input[0])
- out.append(input[1])
- if scale_factors is None and output_size is None:
- assert 0, "Either output_size or scale_factors must be presented"
- if output_size is not None:
- assert (
- scale_factors is None
- ), "Must specify exactly one of output_size and scale_factors"
- assert len(output_size) == 2
- out.append(output_size[0])
- out.append(output_size[1])
- if scale_factors is not None:
- assert (
- output_size is None
- ), "Must specify exactly one of output_size and scale_factors"
- assert len(scale_factors) == 2
- out.append(int(input[2] * scale_factors[0]))
- out.append(int(input[3] * scale_factors[1]))
- return out
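- # For illustration (a sketch; exactly one of output_size / scale_factors is given):
- #   upsample_nearest2d([1, 3, 8, 8], [4, 6], None)      ->  [1, 3, 4, 6]
- #   upsample_nearest2d([1, 3, 8, 8], None, [2.0, 2.0])  ->  [1, 3, 16, 16]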
- def mm(self: List[int], mat2: List[int]):
- assert len(self) == 2, "self must be a matrix"
- assert len(mat2) == 2, "mat2 must be a matrix"
- assert self[1] == mat2[0]
- return [self[0], mat2[1]]
- def dot(self: List[int], tensor: List[int]):
- assert len(self) == 1 and len(tensor) == 1
- assert self[0] == tensor[0]
- out: List[int] = []
- return out
- def mv(self: List[int], vec: List[int]):
- assert len(self) == 2 and len(vec) == 1
- assert self[1] == vec[0]
- # TODO: return self
- return [self[0]]
- def unsqueeze(li: List[int], dim: int):
- dim = maybe_wrap_dim(dim, len(li) + 1)
- out = _copy(li)
- out.insert(dim, 1)
- return out
- def squeeze_nodim(li: List[int]):
- out: List[int] = []
- for i in range(len(li)):
- if li[i] != 1:
- out.append(li[i])
- return out
- def squeeze(li: List[int], dim: int):
- out: List[int] = []
- wrapped_dim = maybe_wrap_dim(dim, len(li))
- for i in range(len(li)):
- if i == wrapped_dim:
- if li[i] != 1:
- out.append(li[i])
- else:
- out.append(li[i])
- return out
- def squeeze_dims(li: List[int], dims: List[int]):
- if len(dims) == 0:
- return li
- wrapped_dims = _copy(dims)
- for i in range(len(dims)):
- wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
- result: List[int] = []
- for i in range(len(li)):
- if li[i] == 1:
- if i not in wrapped_dims:
- result.append(li[i])
- else:
- result.append(li[i])
- return result
- def index_select(self: List[int], dim: int, index: List[int]):
- dim = maybe_wrap_dim(dim, len(self))
- numel = multiply_integers(index)
- assert len(index) <= 1
- assert dim == 0 or dim < len(self)
- result_size: List[int] = []
- for i in range(len(self)):
- if dim == i:
- result_size.append(numel)
- else:
- result_size.append(self[i])
- return result_size
- def embedding(
- weight: List[int],
- indices: List[int],
- padding_idx: int = -1,
- scale_grad_by_freq: bool = False,
- sparse: bool = False,
- ):
- assert len(weight) == 2
- if len(indices) == 1:
- return index_select(weight, 0, indices)
- size = _copy(indices)
- size.append(weight[1])
- return size
- def max_int():
- return 9223372036854775807
- def slice(
- self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
- ):
- ndim = len(self)
- assert ndim != 0
- dim = maybe_wrap_dim(dim, ndim)
- start_val = start if start is not None else 0
- end_val = end if end is not None else max_int()
- assert step > 0
- if start_val == max_int():
- start_val = 0
- if start_val < 0:
- start_val += self[dim]
- if end_val < 0:
- end_val += self[dim]
- if start_val < 0:
- start_val = 0
- elif start_val > self[dim]:
- start_val = self[dim]
- if end_val < start_val:
- end_val = start_val
- elif end_val >= self[dim]:
- end_val = self[dim]
- slice_len = end_val - start_val
- out = _copy(self)
- out[dim] = (slice_len + step - 1) // step
- return out
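- # For illustration (a sketch; None start/end default to the full extent, negatives wrap):
- #   slice([10, 20], 0, 2, None, 3)  ->  [3, 20]   # elements 2, 5, 8 along dim 0
- #   slice([10], 0, -3, None, 1)     ->  [3]       # last three elements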
- def check_cat_no_zero_dim(tensors: List[List[int]]):
- for tensor in tensors:
- assert len(tensor) > 0
- def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
- out_dim: Optional[int] = None
- for size in tensor_sizes:
- if not (len(size) == 1 and size[0] == 0):
- if out_dim is None:
- out_dim = maybe_wrap_dim(dim, len(size))
- if out_dim is None:
- out_dim = dim
- return out_dim
- def should_skip(tensor: List[int]):
- return numel(tensor) == 0 and len(tensor) == 1
- def check_cat_shape_except_dim(
- first: List[int], second: List[int], dimension: int, index: int
- ):
- first_dims = len(first)
- second_dims = len(second)
- assert first_dims == second_dims, "Tensors must have same number of dimensions"
- for dim in range(0, first_dims):
- if dim != dimension:
- assert (
- first[dim] == second[dim]
- ), "Sizes of tensors must match except in dimension"
- def cat(tensors: List[List[int]], dim: int):
- check_cat_no_zero_dim(tensors)
- dim = legacy_cat_wrap_dim(dim, tensors)
- assert len(tensors) > 0
- not_skipped_tensor: Optional[List[int]] = None
- for tensor in tensors:
- if not should_skip(tensor):
- not_skipped_tensor = tensor
- if not_skipped_tensor is None:
- return [0]
- cat_dim_size = 0
- for i in range(len(tensors)):
- tensor = tensors[i]
- if not should_skip(tensor):
- check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
- cat_dim_size = cat_dim_size + tensor[dim]
- result_size = _copy(not_skipped_tensor)
- result_size[dim] = cat_dim_size
- return result_size
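- # For illustration (a sketch; 1-D empty tensors are skipped, matching legacy cat behavior):
- #   cat([[2, 3], [4, 3]], 0)  ->  [6, 3]
- #   cat([[2, 3], [0]], 0)     ->  [2, 3]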
- def stack(tensors: List[List[int]], dim: int):
- unsqueezed_tensors: List[List[int]] = []
- for tensor in tensors:
- unsqueezed = unsqueeze(tensor, dim)
- unsqueezed_tensors.append(unsqueezed)
- return cat(unsqueezed_tensors, dim)
- def select(self: List[int], dim: int, index: int):
- ndim = len(self)
- assert ndim != 0
- dim = maybe_wrap_dim(dim, ndim)
- size = self[dim]
- assert not (index < -size or index >= size)
- if index < 0:
- index += size
- out: List[int] = []
- for i in range(ndim):
- if i != dim:
- out.append(self[i])
- return out
- def matmul(tensor1: List[int], tensor2: List[int]):
- dim_tensor1 = len(tensor1)
- dim_tensor2 = len(tensor2)
- if dim_tensor1 == 1 and dim_tensor2 == 1:
- return dot(tensor1, tensor2)
- elif dim_tensor1 == 2 and dim_tensor2 == 1:
- return mv(tensor1, tensor2)
- elif dim_tensor1 == 1 and dim_tensor2 == 2:
- return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
- elif dim_tensor1 == 2 and dim_tensor2 == 2:
- return mm(tensor1, tensor2)
- elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
- # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists);
- # we track m1 vs m2 separately even though they must match for nicer error messages
- n = tensor1[-2] if dim_tensor1 > 1 else 1
- m1 = tensor1[-1]
- batch_tensor1: List[int] = []
- # TODO: handling of slice
- for i in range(dim_tensor1 - 2):
- batch_tensor1.append(tensor1[i])
- m2 = tensor2[-1] if dim_tensor2 > 1 else 1
- p = tensor2[-1]
- batch_tensor2: List[int] = []
- # TODO: handling of slice
- for i in range(dim_tensor2 - 2):
- batch_tensor2.append(tensor2[i])
- # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
- expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
- # todo: copy ?
- output_shape = expand_batch_portion
- if dim_tensor1 > 1:
- output_shape.append(n)
- if dim_tensor2 > 1:
- output_shape.append(p)
- return output_shape
- else:
- assert False, "both arguments to matmul need to be at least 1D"
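- # For illustration (a sketch of the dispatch above):
- #   matmul([3], [3])             ->  []         # dot product, 0-dim result
- #   matmul([2, 3, 4], [4, 5])    ->  [2, 3, 5]  # batched matrix x matrix
- #   matmul([5, 3, 4], [1, 4, 2]) ->  [5, 3, 2]  # batch dims broadcast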
- def t(self: List[int]):
- assert len(self) <= 2
- self_len = len(self)
- if self_len == 0:
- out: List[int] = []
- return out
- elif self_len == 1:
- return [self[0]]
- else:
- return [self[1], self[0]]
- def transpose(self: List[int], dim0: int, dim1: int):
- ndims = len(self)
- dim0 = maybe_wrap_dim(dim0, ndims)
- dim1 = maybe_wrap_dim(dim1, ndims)
- if dim0 == dim1:
- return _copy(self)
- out: List[int] = []
- for i in range(ndims):
- if i == dim0:
- out.append(self[dim1])
- elif i == dim1:
- out.append(self[dim0])
- else:
- out.append(self[i])
- return out
- def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
- out = matmul(input, t(weight))
- if bias is not None:
- assert broadcast(bias, out) == out
- return out
- def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
- return broadcast(self, mm(mat1, mat2))
- def check_non_negative(array: List[int]) -> bool:
- # TODO: look into rewriting with early return and getting loop unrolling to fire
- # Note: despite the name, this returns True when any value is negative;
- # callers assert `not check_non_negative(...)`.
- non_negative = False
- for val in array:
- if val < 0:
- non_negative = True
- return non_negative
- def check_shape_forward(
- input: List[int],
- weight_sizes: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- groups: int,
- ):
- k = len(input)
- weight_dim = len(weight_sizes)
- # TODO: assertions could be expanded with the error messages
- assert not check_non_negative(padding)
- assert not check_non_negative(stride)
- assert weight_dim == k
- assert weight_sizes[0] >= groups
- assert (weight_sizes[0] % groups) == 0
- # only handling not transposed
- assert input[1] == weight_sizes[1] * groups
- assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])
- for i in range(2, k):
- assert (input[i] + 2 * padding[i - 2]) >= (
- dilation[i - 2] * (weight_sizes[i] - 1) + 1
- )
- # this is not handling transposed convolution yet
- def conv_output_size(
- input_size: List[int],
- weight_size: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- groups: int,
- ):
- check_shape_forward(
- input_size, weight_size, bias, stride, padding, dilation, groups
- )
- has_dilation = len(dilation) > 0
- dim = len(input_size)
- output_size: List[int] = []
- input_batch_size_dim = 0
- weight_output_channels_dim = 0
- output_size.append(input_size[input_batch_size_dim])
- output_size.append(weight_size[weight_output_channels_dim])
- for d in range(2, dim):
- dilation_ = dilation[d - 2] if has_dilation else 1
- kernel = dilation_ * (weight_size[d] - 1) + 1
- output_size.append(
- (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
- )
- return output_size
- def conv1d(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- groups: int,
- ):
- assert len(weight) == 3
- assert len(input) == 3
- return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
- def conv2d(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- groups: int,
- ):
- assert len(weight) == 4
- assert len(input) == 4
- return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
- def conv_backwards(
- grad_output: List[int],
- input: List[int],
- weight: List[int],
- biases: Optional[List[int]],
- ):
- # Bias gradient is always generated regardless of whether biases is supplied
- return _copy(input), _copy(weight), [grad_output[1]]
- def conv_transpose2d_input(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]] = None,
- stride: Optional[List[int]] = None,
- padding: Optional[List[int]] = None,
- output_padding: Optional[List[int]] = None,
- groups: int = 1,
- dilation: Optional[List[int]] = None,
- ) -> List[int]:
- if stride is None:
- stride = [1, 1]
- if padding is None:
- padding = [0, 0]
- if output_padding is None:
- output_padding = [0, 0]
- if dilation is None:
- dilation = [1, 1]
- has_dilation = len(dilation) > 0
- dim = len(input)
- output_size: List[int] = []
- input_batch_size_dim = 0
- weight_output_channels_dim = 1
- output_size.append(input[input_batch_size_dim])
- output_size.append(weight[weight_output_channels_dim] * groups)
- for d in range(2, dim):
- dilation_ = dilation[d - 2] if has_dilation else 1
- kernel = dilation_ * (weight[d] - 1)
- output_size.append(
- (input[d] - 1) * stride[d - 2]
- - 2 * padding[d - 2]
- + kernel
- + output_padding[d - 2]
- + 1
- )
- return output_size
- def conv_forwards(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- transposed: bool,
- output_padding: List[int],
- groups: int,
- ) -> List[int]:
- has_dilation = len(dilation) > 0
- has_output_padding = len(output_padding) > 0
- dim = len(input)
- output_size: List[int] = []
- input_batch_size_dim = 0
- weight_output_channels_dim = 1 if transposed else 0
- output_size.append(input[input_batch_size_dim])
- if transposed:
- output_size.append(weight[weight_output_channels_dim] * groups)
- else:
- output_size.append(weight[weight_output_channels_dim])
- for d in range(2, dim):
- dilation_ = dilation[d - 2] if has_dilation else 1
- output_padding_ = output_padding[d - 2] if has_output_padding else 0
- if transposed:
- kernel = dilation_ * (weight[d] - 1)
- output_size.append(
- (input[d] - 1) * stride[d - 2]
- - 2 * padding[d - 2]
- + kernel
- + output_padding_
- + 1
- )
- else:
- kernel = dilation_ * (weight[d] - 1) + 1
- output_size.append(
- (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
- )
- return output_size
- def _conv_forwards(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- transposed: bool,
- output_padding: List[int],
- groups: int,
- benchmark: bool,
- deterministic: bool,
- cudnn_enabled: bool,
- allow_tf32: bool,
- ) -> List[int]:
- return conv_forwards(
- input,
- weight,
- bias,
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- )
- def batch_norm(
- input: List[int],
- weight: Optional[List[int]],
- bias: Optional[List[int]],
- running_mean: Optional[List[int]],
- running_var: Optional[List[int]],
- training: bool,
- momentum: float,
- eps: float,
- cudnn_enabled: bool,
- ):
- out: List[int] = []
- for elem in input:
- out.append(elem)
- return out
- def conv3d(
- input: List[int],
- weight: List[int],
- bias: Optional[List[int]],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- groups: int,
- ):
- assert len(weight) == 5
- assert len(input) == 5
- return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
- def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
- if dim_post_expr <= 0:
- assert wrap_scalar
- dim_post_expr = 1
- min = -dim_post_expr
- max = dim_post_expr - 1
- assert not (dim < min or dim > max)
- if dim < 0:
- dim += dim_post_expr
- return dim
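- # For illustration (a sketch; negative dims wrap, scalars are treated as 1-dim when allowed):
- #   maybe_wrap_dim(-1, 4)  ->  3
- #   maybe_wrap_dim(1, 4)   ->  1
- #   maybe_wrap_dim(0, 0)   ->  0   # wrap_scalar=True path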
- def zero_dim_tensor(input: Any):
- out: List[int] = []
- return out
- def multiply_integers(li: List[int]):
- out = 1
- for elem in li:
- out = out * elem
- return out
- def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
- assert end >= 0
- return [int(math.ceil(end))]
- def arange_start(
- start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
- ):
- assert end >= 0
- assert end >= start
- return [int(math.ceil(end - start))]
- def arange_start_step(
- start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
- ):
- assert step != 0
- if step < 0:
- assert start >= end
- else:
- assert end >= start
- return [int(math.ceil((end - start) / step))]
- def permute(input: List[int], dims: List[int]):
- assert len(input) == len(dims)
- ndim = len(dims)
- seen_dims: List[int] = []
- newSizes: List[int] = []
- for i in range(ndim):
- dim = maybe_wrap_dim(dims[i], ndim)
- seen_dims.append(dim)
- newSizes.append(input[dim])
- for i in range(1, ndim):
- for j in range(i):
- assert seen_dims[i] != seen_dims[j]
- return newSizes
- def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
- self_dim = len(self)
- if self_dim <= 1:
- return self
- normalized_src: List[int] = []
- normalized_dst: List[int] = []
- for i in range(len(source)):
- normalized_src.append(maybe_wrap_dim(source[i], self_dim))
- normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
- order = [-1 for i in range(self_dim)]
- src_dims = [i for i in range(self_dim)]
- dst_dims = [i for i in range(self_dim)]
- for i in range(len(source)):
- order[normalized_dst[i]] = normalized_src[i]
- src_dims[normalized_src[i]] = -1
- dst_dims[normalized_dst[i]] = -1
- source_dims: List[int] = []
- destination_dims: List[int] = []
- for ele in src_dims:
- if ele != -1:
- source_dims.append(ele)
- for ele in dst_dims:
- if ele != -1:
- destination_dims.append(ele)
- rest_dim = self_dim - len(source)
- for i in range(rest_dim):
- order[destination_dims[i]] = source_dims[i]
- return permute(self, order)
- def flatten(input: List[int], start_dim: int, end_dim: int):
- start_dim = maybe_wrap_dim(start_dim, len(input))
- end_dim = maybe_wrap_dim(end_dim, len(input))
- assert start_dim <= end_dim
- if len(input) == 0:
- return [1]
- if start_dim == end_dim:
- # TODO: return self
- out: List[int] = []
- for elem in input:
- out.append(elem)
- return out
- slice_numel = 1
- for i in range(start_dim, end_dim + 1):
- slice_numel *= input[i]
- # TODO: use slicing when slice optimization has landed
- # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
- shape: List[int] = []
- for i in range(start_dim):
- shape.append(input[i])
- shape.append(slice_numel)
- for i in range(end_dim + 1, len(input)):
- shape.append(input[i])
- return shape
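- # For illustration (a sketch; dims start_dim..end_dim collapse into one):
- #   flatten([2, 3, 4, 5], 1, 2)  ->  [2, 12, 5]
- #   flatten([2, 3, 4], 0, -1)    ->  [24]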
- def nonzero_lower_bound(input: List[int]):
- return [0, len(input)]
- def nonzero_upper_bound(input: List[int]):
- return [numel(input), len(input)]
- def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
- dim = maybe_wrap_dim(dim, len(self))
- out: List[int] = []
- for i, self_dim in enumerate(self):
- if i == dim:
- if keepdim:
- out.append(1)
- else:
- out.append(self_dim)
- return out
- def argmax(
- self: List[int], dim: Optional[int] = None, keepdim: bool = False
- ) -> List[int]:
- if dim is None:
- return []
- return _reduce_along_dim(self, dim, keepdim)
- def bmm(self: List[int], mat2: List[int]) -> List[int]:
- assert len(self) == 3, "bmm only supports 3D tensors"
- assert len(mat2) == 3, "bmm only supports 3D tensors"
- assert self[0] == mat2[0], "mismatching batch dimension"
- assert self[2] == mat2[1], "mismatching contracting dimension"
- return [self[0], self[1], mat2[2]]
- def _shape_as_tensor(self: List[int]) -> List[int]:
- return [len(self)]
- def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
- if len(self) == 0:
- result: List[int] = []
- else:
- assert (
- k <= self[dim]
- ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
- result = _copy(self)
- result[dim] = k
- return result, result
- def nll_loss_forward(
- self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
- ) -> Tuple[List[int], List[int]]:
- # This is taken shamelessly from the meta function in LossNLL.cpp
- self_dim = len(self)
- target_dim = len(target)
- assert 0 < self_dim <= 2
- assert target_dim <= 1
- no_batch_dim = self_dim == 1 and target_dim == 0
- assert no_batch_dim or (self[0] == target[0])
- n_classes = self[-1]
- scalar_shape: List[int] = []
- assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
- if reduction == 0 and self_dim == 2:
- reduction_shape = [self[0]]
- else:
- reduction_shape = scalar_shape
- return reduction_shape, scalar_shape
- def native_layer_norm(
- input: List[int], normalized_shape: List[int]
- ) -> Tuple[List[int], List[int], List[int]]:
- reduction_shape: List[int] = []
- num_unreduced_dimensions = len(input) - len(normalized_shape)
- assert num_unreduced_dimensions >= 0
- for i in range(num_unreduced_dimensions):
- reduction_shape.append(input[i])
- for i in range(num_unreduced_dimensions, len(input)):
- reduction_shape.append(1)
- return _copy(input), reduction_shape, reduction_shape
- def native_batch_norm(
- input: List[int],
- weight: Optional[List[int]],
- bias: Optional[List[int]],
- running_mean: Optional[List[int]],
- running_var: Optional[List[int]],
- training: bool,
- ) -> Tuple[List[int], List[int], List[int]]:
- if training:
- _size = [input[1]]
- else:
- _size = [0]
- return _copy(input), _size, _size
- def _batch_norm_with_update(
- input: List[int],
- weight: Optional[List[int]],
- bias: Optional[List[int]],
- running_mean: Optional[List[int]],
- running_var: Optional[List[int]],
- ) -> Tuple[List[int], List[int], List[int], List[int]]:
- _size = [input[1]]
- return _copy(input), _size, _size, [0]
- def cross_entropy_loss(
- self: List[int],
- target: List[int],
- weight: Optional[List[int]] = None,
- reduction: int = 1,
- ignore_index: int = -100,
- label_smoothing: float = 0.0,
- ) -> List[int]:
- result_shape = nll_loss_forward(self, target, weight, reduction)[0]
- return result_shape
- """
- Currently deferring the enabling of this, as part of the proposal to suspend
- adding ops.
- There are currently cases in the test suite where this is being called
- in the SSA opinfo tests with unexpected values (e.g. a list of two ints; see the first
- opinfo test). The behavior of index is significantly dependent on the inputs.
- This could be an error with how we are matching up shape functions, or that this
- function needs to just implement everything.
- def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
- assert len(indices) <= len(self), "More indices than dimensions to index"
- broadcasted_shape: List[int] = []
- for index_tensor_shape in indices:
- if index_tensor_shape is not None:
- broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
- return broadcasted_shape
- """
- ScriptFn = torch._C.ScriptFunction
- shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
- bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
- script_func_map: Dict[Callable, ScriptFn] = {}
- def process_func(func: Callable):
- if func not in script_func_map:
- scripted_func = torch.jit.script(func)
- torch._C._jit_pass_inline(scripted_func.graph)
- for _ in range(2):
- torch._C._jit_pass_peephole(scripted_func.graph)
- torch._C._jit_pass_constant_propagation(scripted_func.graph)
- script_func_map[func] = scripted_func
- return script_func_map[func]
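- # Minimal usage sketch (illustrative only; assumes TorchScript is available at runtime):
- #   scripted = process_func(unary)   # cached torch._C.ScriptFunction
- #   print(scripted.graph)            # inlined, peephole-optimized shape graph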
- def add_shape_compute_mapping(operator_schema: str, func: Callable):
- global shape_compute_graph_mapping
- shape_compute_graph_mapping[operator_schema] = process_func(func)
- def add_bounded_compute_mapping(
- operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
- ):
- # Adds a shape compute function for both upper and lower bounds
- fns = (process_func(lower_bound_func), process_func(upper_bound_func))
- bounded_compute_graph_mapping[operator_schema] = fns
- add_shape_compute_mapping(
- "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
- unary,
- )
- add_shape_compute_mapping(
- "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
- )
- add_shape_compute_mapping(
- "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
- )
- add_shape_compute_mapping(
- "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
- adaptive_avg_pool2d,
- )
- add_shape_compute_mapping(
- "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
- )
- add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
- add_shape_compute_mapping(
- "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
- unary,
- )
- add_shape_compute_mapping(
- "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
- unary,
- )
- add_shape_compute_mapping(
- "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
- arange_end,
- )
- add_shape_compute_mapping(
- "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
- arange_start,
- )
- add_shape_compute_mapping(
- "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
- arange_start_step,
- )
- add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
- add_shape_compute_mapping(
- "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
- )
- add_shape_compute_mapping(
- "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
- )
- add_shape_compute_mapping(
- "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
- )
- add_shape_compute_mapping(
- "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
- slice,
- )
- add_shape_compute_mapping(
- "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
- )
- add_shape_compute_mapping(
- "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
- )
- add_shape_compute_mapping(
- "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
- "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
- unary,
- )
- add_shape_compute_mapping(
- "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
- )
- add_shape_compute_mapping(
- "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
- unary,
- )
- add_shape_compute_mapping(
- "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
- unary,
- )
- add_shape_compute_mapping(
- "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
- embedding,
- )
- add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
- add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
- add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
- add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
- add_shape_compute_mapping(
- "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
- )
- add_shape_compute_mapping(
- "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
- max_pool2d,
- )
- add_shape_compute_mapping(
- "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
- max_pool2d_with_indices,
- )
- add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
- add_shape_compute_mapping(
- "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
- )
- add_shape_compute_mapping(
- "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
- conv1d,
- )
- add_shape_compute_mapping(
- "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
- conv2d,
- )
- add_shape_compute_mapping(
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
- batch_norm,
- )
- add_shape_compute_mapping(
- "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
- conv3d,
- )
- add_shape_compute_mapping(
- "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
- conv_backwards,
- )
- add_shape_compute_mapping(
- "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
- conv_forwards,
- )
- add_shape_compute_mapping(
- "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
- _conv_forwards,
- )
- add_shape_compute_mapping(
- "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
- conv_transpose2d_input,
- )
- add_shape_compute_mapping(
- "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
- flatten,
- )
- add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
- add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
- add_shape_compute_mapping(
- "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
- )
- add_shape_compute_mapping(
- "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
- movedim,
- )
- add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
- add_shape_compute_mapping(
- "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
- )
- add_shape_compute_mapping(
- "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
- expand_one_unused,
- )
- add_shape_compute_mapping(
- "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
- sum_mean_dim,
- )
- add_shape_compute_mapping(
- "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
- sum_mean_dim,
- )
- add_shape_compute_mapping(
- "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
- max_dim,
- )
- add_shape_compute_mapping(
- "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
- )
- add_shape_compute_mapping(
- "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
- )
- add_shape_compute_mapping(
- "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
- addmm,
- )
- add_shape_compute_mapping(
- "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
- upsample_nearest2d,
- )
- add_shape_compute_mapping(
- "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
- unary,
- )
- add_shape_compute_mapping(
- "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
- unary,
- )
- add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
- add_shape_compute_mapping(
- "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
- broadcast,
- )
- add_shape_compute_mapping(
- "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
- )
- add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
- add_shape_compute_mapping(
- "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
- )
- add_shape_compute_mapping(
- "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
- topk,
- )
- add_shape_compute_mapping(
- "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
- nll_loss_forward,
- )
- add_shape_compute_mapping(
- "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
- native_layer_norm,
- )
- add_shape_compute_mapping(
- "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
- native_batch_norm,
- )
- add_shape_compute_mapping(
- "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
- native_batch_norm,
- )
- add_shape_compute_mapping(
- "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
- native_batch_norm,
- )
- add_shape_compute_mapping(
- "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
- _batch_norm_with_update,
- )
- add_shape_compute_mapping(
- "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
- cross_entropy_loss,
- )
- # add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)
- # TODO: migrate over all of symbolic_shape_registry_util.cpp
- # These are duplicated here so that the functions will be serialized
- add_shape_compute_mapping(
- "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
- broadcast_three,
- )
- add_shape_compute_mapping(
- "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
- broadcast_one_three,
- )
- add_shape_compute_mapping(
- "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
- broadcast_inplace,
- )
- # quantized_conv_prepack TODO
- # Shape Compute Fn with upper and lower bounds
- add_bounded_compute_mapping(
- "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
- )
|