| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055 |
- # mypy: ignore-errors
- """A thin pytorch / numpy compat layer.
- Things imported from here have numpy-compatible signatures but operate on
- pytorch tensors.
- """
- # Contents of this module ends up in the main namespace via _funcs.py
- # where type annotations are used in conjunction with the @normalizer decorator.
- from __future__ import annotations
- import builtins
- import itertools
- import operator
- from typing import Optional, Sequence, TYPE_CHECKING
- import torch
- from . import _dtypes_impl, _util
- if TYPE_CHECKING:
- from ._normalizations import (
- ArrayLike,
- ArrayLikeOrScalar,
- CastingModes,
- DTypeLike,
- NDArray,
- NotImplementedType,
- OutArray,
- )
def copy(
    a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
):
    """Return a copy of the input tensor (numpy.copy analog)."""
    return torch.clone(a)
def copyto(
    dst: NDArray,
    src: ArrayLike,
    casting: Optional[CastingModes] = "same_kind",
    where: NotImplementedType = None,
):
    """Copy ``src`` into ``dst`` in place, casting per the `casting` rule."""
    (src_cast,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
    dst.copy_(src_cast)
def atleast_1d(*arys: ArrayLike):
    """Promote inputs to at least 1-D; multiple results come back as a list."""
    res = torch.atleast_1d(*arys)
    return list(res) if isinstance(res, tuple) else res


def atleast_2d(*arys: ArrayLike):
    """Promote inputs to at least 2-D; multiple results come back as a list."""
    res = torch.atleast_2d(*arys)
    return list(res) if isinstance(res, tuple) else res


def atleast_3d(*arys: ArrayLike):
    """Promote inputs to at least 3-D; multiple results come back as a list."""
    res = torch.atleast_3d(*arys)
    return list(res) if isinstance(res, tuple) else res
- def _concat_check(tup, dtype, out):
- if tup == ():
- raise ValueError("need at least one array to concatenate")
- """Check inputs in concatenate et al."""
- if out is not None and dtype is not None:
- # mimic numpy
- raise TypeError(
- "concatenate() only takes `out` or `dtype` as an "
- "argument, but both were provided."
- )
def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
    """Figure out the common output dtype and cast the inputs if necessary."""
    if out is None and dtype is None:
        target = _dtypes_impl.result_type_impl(*tensors)
    else:
        # either `dtype` or, failing that, `out` pins the output dtype
        target = dtype if dtype is not None else out.dtype.torch_dtype
    # cast input arrays if necessary; do not broadcast them against `out`
    return _util.typecast_tensors(tensors, target, casting)
def _concatenate(
    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
):
    """Pure-torch concatenation; used below and by cov/corrcoef."""
    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
    casted = _concat_cast_helper(tensors, out, dtype, casting)
    return torch.cat(casted, axis)
def concatenate(
    ar_tuple: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """numpy.concatenate analog on torch tensors."""
    _concat_check(ar_tuple, dtype, out=out)
    return _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
def vstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors vertically, row-wise (numpy.vstack analog)."""
    _concat_check(tup, dtype, out=None)
    return torch.vstack(_concat_cast_helper(tup, dtype=dtype, casting=casting))


row_stack = vstack


def hstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors horizontally, column-wise (numpy.hstack analog)."""
    _concat_check(tup, dtype, out=None)
    return torch.hstack(_concat_cast_helper(tup, dtype=dtype, casting=casting))


def dstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors along the third axis (numpy.dstack analog)."""
    # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
    # but {h,v}stack do. Hence add them here for consistency.
    _concat_check(tup, dtype, out=None)
    return torch.dstack(_concat_cast_helper(tup, dtype=dtype, casting=casting))


def column_stack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack 1-D tensors as columns of a 2-D tensor (numpy.column_stack analog)."""
    # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
    # but row_stack does. (because row_stack is an alias for vstack, really).
    # Hence add these keywords here for consistency.
    _concat_check(tup, dtype, out=None)
    return torch.column_stack(_concat_cast_helper(tup, dtype=dtype, casting=casting))
def stack(
    arrays: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """numpy.stack analog: join tensors along a *new* axis."""
    _concat_check(arrays, dtype, out=out)
    tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
    # the new axis is indexed in the result, which has one extra dimension
    axis = _util.normalize_axis_index(axis, tensors[0].ndim + 1)
    return torch.stack(tensors, axis=axis)
def append(arr: ArrayLike, values: ArrayLike, axis=None):
    """numpy.append analog: concatenate ``values`` onto ``arr``."""
    if axis is not None:
        return _concatenate((arr, values), axis=axis)
    # axis=None: operate on the flattened arrays
    flat = arr if arr.ndim == 1 else arr.flatten()
    return _concatenate((flat, values.flatten()), axis=flat.ndim - 1)
# ### split ###


def _split_helper(tensor, indices_or_sections, axis, strict=False):
    """Dispatch split/array_split on the type of `indices_or_sections`."""
    if isinstance(indices_or_sections, int):
        return _split_helper_int(tensor, indices_or_sections, axis, strict)
    if isinstance(indices_or_sections, (list, tuple)):
        # NB: drop strict=..., it only applies to the integer-sections variant
        return _split_helper_list(tensor, list(indices_or_sections), axis)
    raise TypeError("split_helper: ", type(indices_or_sections))
def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
    """Split `tensor` into `indices_or_sections` chunks along `axis`."""
    if not isinstance(indices_or_sections, int):
        raise NotImplementedError("split: indices_or_sections")

    axis = _util.normalize_axis_index(axis, tensor.ndim)

    # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
    length, n = tensor.shape[axis], indices_or_sections
    if n <= 0:
        raise ValueError

    if length % n == 0:
        sizes = [length // n] * n
    else:
        if strict:
            raise ValueError("array split does not result in an equal division")
        num_large = length % n
        sizes = [length // n + 1] * num_large + [length // n] * (n - num_large)

    return torch.split(tensor, sizes, axis)
- def _split_helper_list(tensor, indices_or_sections, axis):
- if not isinstance(indices_or_sections, list):
- raise NotImplementedError("split: indices_or_sections: list")
- # numpy expects indices, while torch expects lengths of sections
- # also, numpy appends zero-size arrays for indices above the shape[axis]
- lst = [x for x in indices_or_sections if x <= tensor.shape[axis]]
- num_extra = len(indices_or_sections) - len(lst)
- lst.append(tensor.shape[axis])
- lst = [
- lst[0],
- ] + [a - b for a, b in zip(lst[1:], lst[:-1])]
- lst += [0] * num_extra
- return torch.split(tensor, lst, axis)
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
    """numpy.array_split analog: uneven divisions are allowed."""
    return _split_helper(ary, indices_or_sections, axis)


def split(ary: ArrayLike, indices_or_sections, axis=0):
    """numpy.split analog: integer sections must divide the axis evenly."""
    return _split_helper(ary, indices_or_sections, axis, strict=True)
def hsplit(ary: ArrayLike, indices_or_sections):
    """Split horizontally: along axis 1 for ndim > 1, else along axis 0."""
    if ary.ndim == 0:
        raise ValueError("hsplit only works on arrays of 1 or more dimensions")
    horizontal_axis = 1 if ary.ndim > 1 else 0
    return _split_helper(ary, indices_or_sections, horizontal_axis, strict=True)


def vsplit(ary: ArrayLike, indices_or_sections):
    """Split vertically (along axis 0); requires at least 2 dimensions."""
    if ary.ndim < 2:
        raise ValueError("vsplit only works on arrays of 2 or more dimensions")
    return _split_helper(ary, indices_or_sections, 0, strict=True)


def dsplit(ary: ArrayLike, indices_or_sections):
    """Split depth-wise (along axis 2); requires at least 3 dimensions."""
    if ary.ndim < 3:
        raise ValueError("dsplit only works on arrays of 3 or more dimensions")
    return _split_helper(ary, indices_or_sections, 2, strict=True)
def kron(a: ArrayLike, b: ArrayLike):
    """Kronecker product of two tensors (numpy.kron analog)."""
    return torch.kron(a, b)


def vander(x: ArrayLike, N=None, increasing=False):
    """Vandermonde matrix of `x` (numpy.vander analog)."""
    return torch.vander(x, N, increasing)
# ### linspace, geomspace, logspace and arange ###


def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    retstep=False,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """numpy.linspace analog; only axis=0, endpoint=True, retstep=False supported."""
    if axis != 0 or retstep or not endpoint:
        raise NotImplementedError
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    # XXX: raises TypeError if start or stop are not scalars
    return torch.linspace(start, stop, num, dtype=dtype)
def geomspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """numpy.geomspace analog: numbers spaced evenly on a log scale."""
    if axis != 0 or not endpoint:
        raise NotImplementedError
    # express the endpoints as powers of a common base, then defer to logspace
    base = torch.pow(stop / start, 1.0 / (num - 1))
    logbase = torch.log(base)
    exp_start = torch.log(start) / logbase
    exp_stop = torch.log(stop) / logbase
    return torch.logspace(exp_start, exp_stop, num, base=base)
def logspace(
    start,
    stop,
    num=50,
    endpoint=True,
    base=10.0,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """numpy.logspace analog; only axis=0 and endpoint=True are supported."""
    if axis != 0 or not endpoint:
        raise NotImplementedError
    return torch.logspace(start, stop, num, base=base, dtype=dtype)
def arange(
    start: Optional[ArrayLikeOrScalar] = None,
    stop: Optional[ArrayLikeOrScalar] = None,
    step: Optional[ArrayLikeOrScalar] = 1,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    """numpy.arange analog: evenly spaced values in the half-open [start, stop)."""
    if step == 0:
        raise ZeroDivisionError
    if stop is None and start is None:
        raise TypeError
    if stop is None:
        # XXX: this breaks if start is passed as a kwarg:
        # arange(start=4) should raise (no stop) but doesn't
        start, stop = 0, start
    if start is None:
        start = 0

    # the dtype of the result: float if any input is floating, else int
    if dtype is None:
        defaults = _dtypes_impl.default_dtypes()
        has_float = any(
            _dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step)
        )
        dtype = defaults.float_dtype if has_float else defaults.int_dtype
    work_dtype = torch.float64 if dtype.is_complex else dtype

    # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager.
    if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)):
        raise NotImplementedError

    if (step > 0 and start > stop) or (step < 0 and start < stop):
        # empty range
        return torch.empty(0, dtype=dtype)

    result = torch.arange(start, stop, step, dtype=work_dtype)
    return _util.cast_if_needed(result, dtype)
# ### zeros/ones/empty/full ###


def empty(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Uninitialized tensor of the given shape (numpy.empty analog)."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.empty(shape, dtype=dtype)


# NB: *_like functions deliberately deviate from numpy: it has subok=True
# as the default; we set subok=False and raise on anything else.


def empty_like(
    prototype: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Uninitialized tensor with the shape/dtype of `prototype`."""
    out = torch.empty_like(prototype, dtype=dtype)
    # an explicit `shape` overrides the prototype's shape
    return out if shape is None else out.reshape(shape)
def full(
    shape,
    fill_value: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of the given shape filled with `fill_value` (numpy.full analog).

    If `dtype` is None it is inherited from `fill_value`.
    """
    if dtype is None:
        dtype = fill_value.dtype
    # normalize a scalar `shape` to a 1-tuple; the original performed this
    # twice (an int check followed by a non-tuple/list check) — the second
    # check subsumes the first, so a single normalization suffices
    if not isinstance(shape, (tuple, list)):
        shape = (shape,)
    return torch.full(shape, fill_value, dtype=dtype)
def full_like(
    a: ArrayLike,
    fill_value,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor shaped like `a`, filled with `fill_value`."""
    # XXX: fill_value broadcasts
    out = torch.full_like(a, fill_value, dtype=dtype)
    # an explicit `shape` overrides the prototype's shape
    return out if shape is None else out.reshape(shape)
def ones(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of ones (numpy.ones analog)."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.ones(shape, dtype=dtype)


def ones_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor of ones with the shape/dtype of `a`."""
    out = torch.ones_like(a, dtype=dtype)
    # an explicit `shape` overrides the prototype's shape
    return out if shape is None else out.reshape(shape)


def zeros(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of zeros (numpy.zeros analog)."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.zeros(shape, dtype=dtype)


def zeros_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor of zeros with the shape/dtype of `a`."""
    out = torch.zeros_like(a, dtype=dtype)
    # an explicit `shape` overrides the prototype's shape
    return out if shape is None else out.reshape(shape)
- # ### cov & corrcoef ###
- def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
- """Prepare inputs for cov and corrcoef."""
- # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
- if y_tensor is not None:
- # make sure x and y are at least 2D
- ndim_extra = 2 - x_tensor.ndim
- if ndim_extra > 0:
- x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape)
- if not rowvar and x_tensor.shape[0] != 1:
- x_tensor = x_tensor.mT
- x_tensor = x_tensor.clone()
- ndim_extra = 2 - y_tensor.ndim
- if ndim_extra > 0:
- y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape)
- if not rowvar and y_tensor.shape[0] != 1:
- y_tensor = y_tensor.mT
- y_tensor = y_tensor.clone()
- x_tensor = _concatenate((x_tensor, y_tensor), axis=0)
- return x_tensor
def corrcoef(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=None,
    ddof=None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """numpy.corrcoef analog; `bias`/`ddof` are deprecated in numpy and rejected."""
    if bias is not None or ddof is not None:
        # deprecated in NumPy
        raise NotImplementedError

    xy_tensor = _xy_helper_corrcoef(x, y, rowvar)

    # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
    half_cpu = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu
    if half_cpu:
        dtype = torch.float32
    xy_tensor = _util.cast_if_needed(xy_tensor, dtype)

    result = torch.corrcoef(xy_tensor)
    return result.to(torch.float16) if half_cpu else result
def cov(
    m: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=False,
    ddof=None,
    fweights: Optional[ArrayLike] = None,
    aweights: Optional[ArrayLike] = None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """numpy.cov analog, implemented via torch.cov."""
    m = _xy_helper_corrcoef(m, y, rowvar)

    if ddof is None:
        # numpy semantics: default correction is 1 unless `bias` is truthy
        ddof = 1 if bias == 0 else 0

    # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
    half_cpu = (m.dtype == torch.float16) and m.is_cpu
    if half_cpu:
        dtype = torch.float32
    m = _util.cast_if_needed(m, dtype)

    result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights)
    return result.to(torch.float16) if half_cpu else result
def _conv_corr_impl(a, v, mode):
    """Shared core of convolve/correlate, via torch.nn.functional.conv1d."""
    dt = _dtypes_impl.result_type_impl(a, v)
    a = _util.cast_if_needed(a, dt)
    v = _util.cast_if_needed(v, dt)

    padding = v.shape[0] - 1 if mode == "full" else mode

    if padding == "same" and v.shape[0] % 2 == 0:
        # UserWarning: Using padding='same' with even kernel lengths and odd
        # dilation may require a zero-padded copy of the input be created
        # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.)
        raise NotImplementedError("mode='same' and even-length weights")

    # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
    result = torch.nn.functional.conv1d(a[None, :], v[None, None, :], padding=padding)

    # torch returns a 2D result, numpy returns a 1D array
    return result[0, :]
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
    """numpy.convolve analog."""
    # NumPy: if v is longer than a, the arrays are swapped before computation
    if a.shape[0] < v.shape[0]:
        a, v = v, a
    # flip the weights since numpy does and torch does not
    return _conv_corr_impl(a, torch.flip(v, (0,)), mode)


def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
    """numpy.correlate analog (conjugates the weights)."""
    return _conv_corr_impl(a, torch.conj_physical(v), mode)
# ### logic & element selection ###


def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
    """numpy.bincount analog; an empty input is allowed (edge case in numpy)."""
    if x.numel() == 0:
        # edge case allowed by numpy
        x = x.new_empty(0, dtype=int)

    default_int = _dtypes_impl.default_dtypes().int_dtype
    (x,) = _util.typecast_tensors((x,), default_int, casting="safe")
    return torch.bincount(x, weights, minlength)
def where(
    condition: ArrayLike,
    x: Optional[ArrayLikeOrScalar] = None,
    y: Optional[ArrayLikeOrScalar] = None,
    /,
):
    """numpy.where analog: one-arg form returns indices, three-arg form selects."""
    has_x = x is not None
    if has_x != (y is not None):
        raise ValueError("either both or neither of x and y should be given")
    if condition.dtype != torch.bool:
        condition = condition.to(torch.bool)
    if has_x:
        return torch.where(condition, x, y)
    return torch.where(condition)
# ###### module-level queries of object properties


def ndim(a: ArrayLike):
    """Number of dimensions of `a`."""
    return a.ndim


def shape(a: ArrayLike):
    """Shape of `a` as a plain tuple."""
    return tuple(a.shape)


def size(a: ArrayLike, axis=None):
    """Total number of elements, or the extent along `axis` if given."""
    return a.numel() if axis is None else a.shape[axis]
# ###### shape manipulations and indexing


def expand_dims(a: ArrayLike, axis):
    """Insert new size-1 axes at the position(s) given by `axis`."""
    new_shape = _util.expand_shape(a.shape, axis)
    return a.view(new_shape)  # never copies
def flip(m: ArrayLike, axis=None):
    """Reverse `m` along the given axes (all axes when axis is None)."""
    # XXX: semantic difference: np.flip returns a view, torch.flip copies
    if axis is None:
        axes = tuple(range(m.ndim))
    else:
        axes = _util.normalize_axis_tuple(axis, m.ndim)
    return torch.flip(m, axes)
def flipud(m: ArrayLike):
    """Reverse rows (flip along axis 0)."""
    return torch.flipud(m)


def fliplr(m: ArrayLike):
    """Reverse columns (flip along axis 1)."""
    return torch.fliplr(m)
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
    """Rotate `m` by 90*k degrees in the plane given by `axes`."""
    normalized_axes = _util.normalize_axis_tuple(axes, m.ndim)
    return torch.rot90(m, k, normalized_axes)
# ### broadcasting and indices ###


def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
    """Broadcast `array` to the given shape (numpy.broadcast_to analog)."""
    return torch.broadcast_to(array, size=shape)
- # This is a function from tuples to tuples, so we just reuse it
- from torch import broadcast_shapes
def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
    """Broadcast the inputs against each other (numpy.broadcast_arrays analog)."""
    return torch.broadcast_tensors(*args)
def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
    """numpy.meshgrid analog; returns a list to match numpy."""
    ndim = len(xi)
    if indexing not in ["xy", "ij"]:
        raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")

    unit = (1,) * ndim
    grids = [x.reshape(unit[:i] + (-1,) + unit[i + 1 :]) for i, x in enumerate(xi)]

    if indexing == "xy" and ndim > 1:
        # switch first and second axis
        grids[0] = grids[0].reshape((1, -1) + unit[2:])
        grids[1] = grids[1].reshape((-1, 1) + unit[2:])

    if not sparse:
        # Return the full N-D matrix (not only the 1-D vector)
        grids = torch.broadcast_tensors(*grids)

    if copy:
        grids = [g.clone() for g in grids]

    return list(grids)  # match numpy, return a list
def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False):
    """numpy.indices analog: grid of index values for the given shape."""
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
    dimensions = tuple(dimensions)
    rank = len(dimensions)
    unit = (1,) * rank
    res = tuple() if sparse else torch.empty((rank,) + dimensions, dtype=dtype)
    for i, dim in enumerate(dimensions):
        idx = torch.arange(dim, dtype=dtype).reshape(
            unit[:i] + (dim,) + unit[i + 1 :]
        )
        if sparse:
            res = res + (idx,)
        else:
            res[i] = idx
    return res
# ### tri*-something ###


def tril(m: ArrayLike, k=0):
    """Lower triangle of `m` at and below the k-th diagonal."""
    return torch.tril(m, k)


def triu(m: ArrayLike, k=0):
    """Upper triangle of `m` at and above the k-th diagonal."""
    return torch.triu(m, k)
def tril_indices(n, k=0, m=None):
    """Indices of the lower triangle of an (n, m) matrix; m defaults to n."""
    return torch.tril_indices(n, n if m is None else m, offset=k)


def triu_indices(n, k=0, m=None):
    """Indices of the upper triangle of an (n, m) matrix; m defaults to n."""
    return torch.triu_indices(n, n if m is None else m, offset=k)
def tril_indices_from(arr: ArrayLike, k=0):
    """Lower-triangle indices matching the shape of the 2-D `arr`."""
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    rows, cols = arr.shape
    return torch.tril_indices(rows, cols, offset=k)


def triu_indices_from(arr: ArrayLike, k=0):
    """Upper-triangle indices matching the shape of the 2-D `arr`."""
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    rows, cols = arr.shape
    return torch.triu_indices(rows, cols, offset=k)
def tri(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    """numpy.tri analog: ones at and below the k-th diagonal, zeros above."""
    if M is None:
        M = N
    return torch.tril(torch.ones((N, M), dtype=dtype), diagonal=k)
# ### equality, equivalence, allclose ###


def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
    """Element-wise closeness test after casting both inputs to a common dtype."""
    common = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, common)
    b = _util.cast_if_needed(b, common)
    return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
    """True if all elements are close after casting to a common dtype."""
    common = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, common)
    b = _util.cast_if_needed(b, common)
    return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
- def _tensor_equal(a1, a2, equal_nan=False):
- # Implementation of array_equal/array_equiv.
- if a1.shape != a2.shape:
- return False
- cond = a1 == a2
- if equal_nan:
- cond = cond | (torch.isnan(a1) & torch.isnan(a2))
- return cond.all().item()
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False):
    """True if `a1` and `a2` have the same shape and equal elements."""
    return _tensor_equal(a1, a2, equal_nan=equal_nan)


def array_equiv(a1: ArrayLike, a2: ArrayLike):
    """Like array_equal, but the inputs only need to be broadcast-compatible."""
    # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not
    try:
        b1, b2 = torch.broadcast_tensors(a1, a2)
    except RuntimeError:
        # failed to broadcast => not equivalent
        return False
    return _tensor_equal(b1, b2)
def nan_to_num(
    x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None
):
    """numpy.nan_to_num analog; complex inputs are handled per-component."""
    if not x.is_complex():
        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble'
    real = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf)
    imag = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf)
    return real + 1j * imag
# ### put/take_along_axis ###


def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    """numpy.take analog: select elements along `axis` by integer index."""
    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)
    # build `a[:, ..., indices, ...]` with `indices` in position `axis`
    selector = (slice(None),) * axis + (indices, ...)
    return a[selector]


def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
    """numpy.take_along_axis analog, via torch.take_along_dim."""
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)
    return torch.take_along_dim(arr, indices, axis)
def put(
    a: NDArray,
    indices: ArrayLike,
    values: ArrayLike,
    mode: NotImplementedType = "raise",
):
    """numpy.put analog: scatter `values` into the flattened `a`, in place."""
    v = values.type(a.dtype)
    n_idx, n_val = indices.numel(), v.numel()
    # If indices is larger than v, expand v to at least the size of indices. Any
    # unnecessary trailing elements are then trimmed.
    if n_idx > n_val:
        ratio = (n_idx + n_val - 1) // n_val
        v = v.unsqueeze(0).expand((ratio,) + v.shape)
    # Trim unnecessary elements, regardless if v was expanded or not. Note
    # np.put() trims v to match indices by default too.
    if n_idx < v.numel():
        v = v.flatten()[:n_idx]
    a.put_(indices, v)
    return None
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
    """numpy.put_along_axis analog: scatter `values` into `arr`, in place."""
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)

    indices, values = torch.broadcast_tensors(indices, values)
    values = _util.cast_if_needed(values, arr.dtype)
    scattered = torch.scatter(arr, axis, indices, values)
    arr.copy_(scattered.reshape(arr.shape))
    return None
def choose(
    a: ArrayLike,
    choices: Sequence[ArrayLike],
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    """numpy.choose analog: pick elements from `choices` according to `a`."""
    # First, broadcast elements of `choices`
    choices = torch.stack(torch.broadcast_tensors(*choices))

    # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`:
    # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939)
    index_arrays = [
        torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1))
        for i, dim in enumerate(choices.shape)
    ]
    index_arrays[0] = a
    return choices[index_arrays].squeeze(0)
# ### unique et al. ###


def unique(
    ar: ArrayLike,
    return_index: NotImplementedType = False,
    return_inverse=False,
    return_counts=False,
    axis=None,
    *,
    equal_nan: NotImplementedType = True,
):
    """numpy.unique analog (return_index is not supported)."""
    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
    axis = _util.normalize_axis_index(axis, ar.ndim)
    return torch.unique(
        ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis
    )
def nonzero(a: ArrayLike):
    """Indices of nonzero elements, as a tuple of index tensors (numpy style)."""
    return torch.nonzero(a, as_tuple=True)


def argwhere(a: ArrayLike):
    """Coordinates of nonzero elements, one row per element."""
    return torch.argwhere(a)


def flatnonzero(a: ArrayLike):
    """Indices of nonzero elements in the flattened input."""
    return torch.flatten(a).nonzero(as_tuple=True)[0]
def clip(
    a: ArrayLike,
    min: Optional[ArrayLike] = None,
    max: Optional[ArrayLike] = None,
    out: Optional[OutArray] = None,
):
    """numpy.clip analog (`out` is presumably handled by the @normalizer layer
    mentioned in the module header — confirm before relying on it)."""
    return torch.clamp(a, min, max)
def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
    """numpy.repeat; axis=None flattens, matching repeat_interleave's dim=None."""
    return a.repeat_interleave(repeats, dim=axis)
def tile(A: ArrayLike, reps):
    """numpy.tile: repeat A according to reps (an int or a tuple of ints)."""
    reps_tuple = (reps,) if isinstance(reps, int) else reps
    return torch.tile(A, reps_tuple)
def resize(a: ArrayLike, new_shape=None):
    """numpy.resize: return a new array of `new_shape`, repeating `a` as needed.

    implementation vendored from
    https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
    """
    if new_shape is None:
        return a

    if isinstance(new_shape, int):
        new_shape = (new_shape,)

    a = a.flatten()

    new_size = 1
    for extent in new_shape:
        new_size *= extent
        if extent < 0:
            raise ValueError("all elements of `new_shape` must be non-negative")

    if a.numel() == 0 or new_size == 0:
        # First case must zero fill. The second would have repeats == 0.
        return torch.zeros(new_shape, dtype=a.dtype)

    # ceil division: enough copies of `a` to cover new_size, then trim
    n_copies = -(-new_size // a.numel())
    a = concatenate((a,) * n_copies)[:new_size]

    return reshape(a, new_shape)
- # ### diag et al. ###
def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
    """numpy.diagonal: view of the `offset` diagonal taken over axis1/axis2."""
    axis1 = _util.normalize_axis_index(axis1, a.ndim)
    axis2 = _util.normalize_axis_index(axis2, a.ndim)
    return a.diagonal(offset, axis1, axis2)
def trace(
    a: ArrayLike,
    offset=0,
    axis1=0,
    axis2=1,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """numpy.trace: sum along the `offset` diagonal of the axis1/axis2 planes."""
    # NOTE(review): `out` appears to be handled by the normalization layer — confirm
    return torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(dim=-1, dtype=dtype)
def eye(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """numpy.eye: N x M matrix with ones on the k-th diagonal, zeros elsewhere."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    cols = N if M is None else M
    # torch.eye has no `k` argument, so write the diagonal by hand
    result = torch.zeros(N, cols, dtype=dtype)
    result.diagonal(k).fill_(1)
    return result
def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None):
    """numpy.identity: the n x n identity matrix (`like` is not supported)."""
    return torch.eye(n, dtype=dtype)
def diag(v: ArrayLike, k=0):
    """numpy.diag: extract the k-th diagonal, or build a matrix from a vector."""
    return v.diag(k)
def diagflat(v: ArrayLike, k=0):
    """numpy.diagflat: matrix with the flattened input on its k-th diagonal."""
    return v.diagflat(k)
def diag_indices(n, ndim=2):
    """Index tensors addressing the main diagonal of an ndim-dimensional array of side n."""
    rng = torch.arange(n)
    # the same index tensor is shared across all ndim positions
    return tuple(rng for _ in range(ndim))
def diag_indices_from(arr: ArrayLike):
    """numpy.diag_indices_from: diagonal indices for a square (hyper)cube array."""
    if arr.ndim < 2:
        raise ValueError("input array must be at least 2-d")
    # For more than d=2, the strided formula is only valid for arrays with
    # all dimensions equal, so we check first.
    shape = arr.shape
    if shape[1:] != shape[:-1]:
        raise ValueError("All dimensions of input must be of equal length")
    return diag_indices(shape[0], arr.ndim)
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
    """numpy.fill_diagonal: write `val` along the diagonal of `a`, in place."""
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    if val.numel() == 0 and not wrap:
        a.fill_diagonal_(val)
        return a

    if val.ndim == 0:
        val = val.unsqueeze(0)

    # torch.Tensor.fill_diagonal_ only accepts scalars
    # If the size of val is too large, then val is trimmed

    if a.ndim == 2:
        tall = a.shape[0] > a.shape[1]
        # wrap does nothing for wide matrices...
        if not wrap or not tall:
            # Never wraps
            diag = a.diagonal()
            diag.copy_(val[: diag.numel()])
        else:
            # wraps and tall... leaving one empty line between diagonals?!
            max_, min_ = a.shape
            idx = torch.arange(max_ - max_ // (min_ + 1))
            mod = idx % min_
            div = idx // min_
            # scatter val onto the wrapped diagonal positions via advanced indexing
            a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()]
    else:
        idx = diag_indices_from(a)
        # a.shape = (n, n, ..., n)
        # NOTE(review): assumes val has at least a.shape[0] elements when longer
        # trimming is needed — shorter val relies on advanced-indexing broadcast; confirm
        a[idx] = val[: a.shape[0]]

    return a
def vdot(a: ArrayLike, b: ArrayLike, /):
    """numpy.vdot: dot product of flattened inputs (conjugating the first).

    1. torch only accepts 1D arrays, numpy flattens
    2. torch requires matching dtype, while numpy casts (?)
    """
    x, y = torch.atleast_1d(a, b)
    if x.ndim > 1:
        x = x.flatten()
    if y.ndim > 1:
        y = y.flatten()

    dtype = _dtypes_impl.result_type_impl(x, y)
    half_on_cpu = dtype == torch.float16 and (x.is_cpu or y.is_cpu)
    bool_inputs = dtype == torch.bool

    # work around torch's "dot" not implemented for 'Half', 'Bool'
    if half_on_cpu:
        dtype = torch.float32
    elif bool_inputs:
        dtype = torch.uint8

    x = _util.cast_if_needed(x, dtype)
    y = _util.cast_if_needed(y, dtype)

    result = torch.vdot(x, y)

    # cast back to the dtype the workaround displaced
    if half_on_cpu:
        result = result.to(torch.float16)
    elif bool_inputs:
        result = result.to(torch.bool)

    return result
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
    """numpy.tensordot with numpy's type promotion applied before the contraction."""
    if isinstance(axes, (list, tuple)):
        # torch wants sequences of axes; wrap any bare ints
        axes = [[ax] if isinstance(ax, int) else ax for ax in axes]

    common = _dtypes_impl.result_type_impl(a, b)
    return torch.tensordot(
        _util.cast_if_needed(a, common), _util.cast_if_needed(b, common), dims=axes
    )
def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    """numpy.dot: scalar multiply for 0-d operands, matmul otherwise."""
    dtype = _dtypes_impl.result_type_impl(a, b)
    # matmul has no Bool kernel; compute in uint8 and cast the result back
    bool_result = dtype == torch.bool
    if bool_result:
        dtype = torch.uint8

    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    if a.ndim == 0 or b.ndim == 0:
        result = a * b
    else:
        result = torch.matmul(a, b)

    return result.to(torch.bool) if bool_result else result
def inner(a: ArrayLike, b: ArrayLike, /):
    """numpy.inner: inner product over the last axes, with numpy type promotion."""
    dtype = _dtypes_impl.result_type_impl(a, b)
    half_on_cpu = dtype == torch.float16 and (a.is_cpu or b.is_cpu)
    bool_inputs = dtype == torch.bool

    if half_on_cpu:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32
    elif bool_inputs:
        dtype = torch.uint8

    result = torch.inner(
        _util.cast_if_needed(a, dtype), _util.cast_if_needed(b, dtype)
    )

    # undo the workaround cast
    if half_on_cpu:
        result = result.to(torch.float16)
    elif bool_inputs:
        result = result.to(torch.bool)

    return result
def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    """numpy.outer: outer product of two vectors.

    numpy flattens inputs that are not already 1-dimensional, while
    torch.outer rejects them — ravel first so >1-d inputs work as in numpy.
    """
    return torch.outer(a.ravel(), b.ravel())
def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """numpy.cross: 2- or 3-component cross product along the given axes."""
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685
    if axis is not None:
        # a single `axis` overrides all three per-operand axes
        axisa, axisb, axisc = (axis,) * 3

    # Check axisa and axisb are within bounds
    axisa = _util.normalize_axis_index(axisa, a.ndim)
    axisb = _util.normalize_axis_index(axisb, b.ndim)

    # Move working axis to the end of the shape
    a = torch.moveaxis(a, axisa, -1)
    b = torch.moveaxis(b, axisb, -1)
    msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)"
    if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
        raise ValueError(msg)

    # Create the output array
    shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape)
    if a.shape[-1] == 3 or b.shape[-1] == 3:
        shape += (3,)
        # Check axisc is within bounds
        axisc = _util.normalize_axis_index(axisc, len(shape))
    dtype = _dtypes_impl.result_type_impl(a, b)
    cp = torch.empty(shape, dtype=dtype)

    # recast arrays as dtype
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    # create local aliases for readability
    a0 = a[..., 0]
    a1 = a[..., 1]
    if a.shape[-1] == 3:
        a2 = a[..., 2]
    b0 = b[..., 0]
    b1 = b[..., 1]
    if b.shape[-1] == 3:
        b2 = b[..., 2]
    if cp.ndim != 0 and cp.shape[-1] == 3:
        cp0 = cp[..., 0]
        cp1 = cp[..., 1]
        cp2 = cp[..., 2]

    if a.shape[-1] == 2:
        if b.shape[-1] == 2:
            # both 2D vectors: result is the scalar z-component
            # a0 * b1 - a1 * b0
            cp[...] = a0 * b1 - a1 * b0
            return cp
        else:
            assert b.shape[-1] == 3
            # cp0 = a1 * b2 - 0  (a2 = 0)
            # cp1 = 0 - a0 * b2  (a2 = 0)
            # cp2 = a0 * b1 - a1 * b0
            cp0[...] = a1 * b2
            cp1[...] = -a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
    else:
        assert a.shape[-1] == 3
        if b.shape[-1] == 3:
            # full 3D cross product
            cp0[...] = a1 * b2 - a2 * b1
            cp1[...] = a2 * b0 - a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
        else:
            assert b.shape[-1] == 2
            # b2 = 0 implicitly
            cp0[...] = -a2 * b1
            cp1[...] = a2 * b0
            cp2[...] = a0 * b1 - a1 * b0

    # move the result axis back to the requested position
    return torch.moveaxis(cp, -1, axisc)
def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    """numpy.einsum on torch tensors, supporting both the subscripts-string and
    the interleaved sublist calling conventions."""
    # Have to manually normalize *operands and **kwargs, following the NumPy signature
    # We have a local import to avoid poluting the global space, as it will be then
    # exported in funcs.py
    from ._ndarray import ndarray
    from ._normalizations import (
        maybe_copy_to,
        normalize_array_like,
        normalize_casting,
        normalize_dtype,
        wrap_tensors,
    )

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, str, op, str ... [sublistout] format: normalize every other argument
        # - if sublistout is not given, the length of operands is even, and we pick
        #   odd-numbered elements, which are arrays.
        # - if sublistout is given, the length of operands is odd, we peel off
        #   the last one, and pick odd-numbered elements, which are arrays.
        #   Without [:-1], we would have picked sublistout, too.
        array_operands = operands[:-1][::2]
    else:
        # ("ij->", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

    # work around 'bmm' not implemented for 'Half' etc
    is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors)
    if is_half:
        target_dtype = torch.float32

    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
    if is_short_int:
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    from torch.backends import opt_einsum

    try:
        # set the global state to handle the optimize=... argument, restore on exit
        if opt_einsum.is_available():
            old_strategy = torch.backends.opt_einsum.strategy
            old_enabled = torch.backends.opt_einsum.enabled

            # torch.einsum calls opt_einsum.contract_path, which runs into
            # https://github.com/dgasmith/opt_einsum/issues/219
            # for strategy={True, False}
            if optimize is True:
                optimize = "auto"
            elif optimize is False:
                torch.backends.opt_einsum.enabled = False

            torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine operands: interleave normalized tensors with their sublists
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        # restore the global opt_einsum state even if torch.einsum raised
        if opt_einsum.is_available():
            torch.backends.opt_einsum.strategy = old_strategy
            torch.backends.opt_einsum.enabled = old_enabled

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)
- # ### sort and partition ###
def _sort_helper(tensor, axis, kind, order):
    """Shared preprocessing for sort/argsort: reject complex, resolve axis, map kind."""
    if tensor.dtype.is_complex:
        raise NotImplementedError(f"sorting {tensor.dtype} is not supported")
    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
    axis = _util.normalize_axis_index(axis, tensor.ndim)
    # only kind == "stable" maps onto torch's stable flag
    return tensor, axis, kind == "stable"
def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """numpy.sort; `order` only applies to structured dtypes, hence unsupported."""
    a, axis, stable = _sort_helper(a, axis, kind, order)
    return torch.sort(a, dim=axis, stable=stable).values
def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """numpy.argsort: indices that would sort the array along `axis`."""
    tensor, dim, stable = _sort_helper(a, axis, kind, order)
    return torch.argsort(tensor, dim=dim, stable=stable)
def searchsorted(
    a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
):
    """numpy.searchsorted: insertion indices of `v` into the sorted array `a`."""
    if a.dtype.is_complex:
        raise NotImplementedError(f"searchsorted with dtype={a.dtype}")

    return torch.searchsorted(a, v, side=side, sorter=sorter)
- # ### swap/move/roll axis ###
def moveaxis(a: ArrayLike, source, destination):
    """numpy.moveaxis: move axes from `source` positions to `destination`."""
    src = _util.normalize_axis_tuple(source, a.ndim, "source")
    dst = _util.normalize_axis_tuple(destination, a.ndim, "destination")
    return torch.moveaxis(a, src, dst)
def swapaxes(a: ArrayLike, axis1, axis2):
    """numpy.swapaxes: interchange two axes of the array."""
    axis1 = _util.normalize_axis_index(axis1, a.ndim)
    axis2 = _util.normalize_axis_index(axis2, a.ndim)
    return a.swapaxes(axis1, axis2)
def rollaxis(a: ArrayLike, axis, start=0):
    """numpy.rollaxis: roll `axis` until it lies at position `start`.

    Straight vendor from:
    https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259

    Also note this function in NumPy is mostly retained for backwards compat
    (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
    so let's not touch it unless hard pressed.
    """
    n = a.ndim
    axis = _util.normalize_axis_index(axis, n)
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
    if axis < start:
        # it's been removed
        start -= 1
    if axis == start:
        # numpy returns a view, here we try returning the tensor itself
        # return tensor[...]
        return a
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    # `axes` is a permutation of the dims; numpy does a.transpose(axes) here.
    # The old `a.view(axes)` reshaped instead of permuting (and would error,
    # since `axes` are dim indices, not sizes) — permute is the correct op.
    return a.permute(axes)
def roll(a: ArrayLike, shift, axis=None):
    """numpy.roll: circularly shift elements along the given axis (or flattened)."""
    if axis is not None:
        axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
        if not isinstance(shift, tuple):
            # a scalar shift applies to every listed axis
            shift = (shift,) * len(axis)
    return torch.roll(a, shift, axis)
- # ### shape manipulations ###
def squeeze(a: ArrayLike, axis=None):
    """numpy.squeeze: remove size-1 dimensions, all of them or the given ones.

    axis == () is a no-op; axis may be an int or a tuple of ints.
    """
    if axis == ():
        result = a
    elif axis is None:
        result = a.squeeze()
    else:
        # Tensor.squeeze accepts either a single dim or a tuple of dims
        # (torch >= 2.0). This also fixes the old tuple branch, which looped
        # `result = a.squeeze(ax)` — it kept only the *last* axis's squeeze and
        # ignored the index shifting caused by earlier removals.
        result = a.squeeze(axis)
    return result
def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
    """numpy-compatible reshape: accepts both reshape(sh) and reshape(*sh)."""
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh);
    # the normalization layer packs *args, so unwrap a single-element tuple.
    target = newshape[0] if len(newshape) == 1 else newshape
    # NB: the Tensor.reshape *method* is used deliberately:
    #   torch.reshape(torch.as_tensor([1]), 1)
    #   *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int
    return a.reshape(target)
def transpose(a: ArrayLike, axes=None):
    """numpy.transpose: permute axes; default reverses them.

    numpy allows both .transpose(sh) and .transpose(*sh); older code also
    passes axes as a list.
    """
    if axes in [(), None, (None,)]:
        axes = tuple(range(a.ndim - 1, -1, -1))
    elif len(axes) == 1:
        # the .transpose(sh) spelling: the single element is the real axes seq
        axes = axes[0]
    return a.permute(axes)
def ravel(a: ArrayLike, order: NotImplementedType = "C"):
    """numpy.ravel: the flattened array (only C order is supported)."""
    return a.flatten()
def diff(
    a: ArrayLike,
    n=1,
    axis=-1,
    prepend: Optional[ArrayLike] = None,
    append: Optional[ArrayLike] = None,
):
    """numpy.diff: n-th discrete difference along `axis`, with optional edges."""
    axis = _util.normalize_axis_index(axis, a.ndim)

    if n < 0:
        raise ValueError(f"order must be non-negative but got {n}")

    if n == 0:
        # match numpy and return the input immediately
        return a

    def _expand_edge(edge):
        # numpy accepts scalar prepend/append; broadcast along `axis` to match a
        target = list(a.shape)
        target[axis] = edge.shape[axis] if edge.ndim > 0 else 1
        return torch.broadcast_to(edge, target)

    if prepend is not None:
        prepend = _expand_edge(prepend)
    if append is not None:
        append = _expand_edge(append)

    return torch.diff(a, n, axis=axis, prepend=prepend, append=append)
- # ### math functions ###
def angle(z: ArrayLike, deg=False):
    """numpy.angle: argument of complex values, in radians (or degrees if deg)."""
    result = torch.angle(z)
    return result * (180 / torch.pi) if deg else result
def sinc(x: ArrayLike):
    """numpy.sinc: the normalized sinc function, sin(pi x) / (pi x)."""
    return x.sinc()
# NB: have to normalize *varargs manually
def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
    """numpy.gradient: numerical gradient of f along the requested axes.

    varargs gives the sample spacing: nothing (spacing 1), one scalar for all
    axes, or one scalar/1d-array per axis. Central differences are used in the
    interior and one-sided differences of order `edge_order` at the boundaries.
    """
    N = f.ndim  # number of dimensions

    varargs = _util.ndarrays_to_tensors(varargs)

    if axis is None:
        axes = tuple(range(N))
    else:
        axes = _util.normalize_axis_tuple(axis, N)

    len_axes = len(axes)
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
        # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = torch.as_tensor(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1d")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError(
                    "when 1d, distances must match "
                    "the length of the corresponding dimension"
                )
            if not (distances.dtype.is_floating_point or distances.dtype.is_complex):
                distances = distances.double()

            diffx = torch.diff(distances)
            # if distances are constant reduce to the scalar case
            # since it brings a consistent speedup
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    # use central differences on interior and one-sided differences on the
    # endpoints. This preserves second order-accuracy over the full domain.
    outvals = []

    # create slice objects --- initially all are [:, :, ..., :]
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    if _dtypes_impl.python_type_for_torch(otype) in (int, bool):
        # Convert to floating point.
        # First check if f is a numpy integer type; if so, convert f to float64
        # to avoid modular arithmetic when computing the changes in f.
        f = f.double()
        otype = torch.float64

    for axis, ax_dx in zip(axes, dx):
        if f.shape[axis] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required."
            )

        # result allocation
        out = torch.empty_like(f, dtype=otype)

        # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

        # Numerical differentiation: 2nd order interior
        slice1[axis] = slice(1, -1)
        slice2[axis] = slice(None, -2)
        slice3[axis] = slice(1, -1)
        slice4[axis] = slice(2, None)

        if uniform_spacing:
            out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
        else:
            # non-uniform spacing: weighted three-point central difference
            dx1 = ax_dx[0:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))

            # fix the shape for broadcasting
            shape = [1] * N
            shape[axis] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)

            # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[axis] = 0
            slice2[axis] = 1
            slice3[axis] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0

            slice1[axis] = -1
            slice2[axis] = -1
            slice3[axis] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n

        # Numerical differentiation: 2nd order edges
        else:
            slice1[axis] = 0
            slice2[axis] = 0
            slice3[axis] = 1
            slice4[axis] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2.0 / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[axis] = -1
            slice2[axis] = -3
            slice3[axis] = -2
            slice4[axis] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2.0 / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = (dx2) / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[axis] = slice(None)
        slice2[axis] = slice(None)
        slice3[axis] = slice(None)
        slice4[axis] = slice(None)

    if len_axes == 1:
        return outvals[0]
    else:
        return outvals
- # ### Type/shape etc queries ###
def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
    """numpy.round: round to `decimals` places (around/round_ are aliases)."""
    if a.is_floating_point():
        return torch.round(a, decimals=decimals)
    if a.is_complex():
        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat':
        # round the real and imaginary parts separately
        return torch.complex(
            torch.round(a.real, decimals=decimals),
            torch.round(a.imag, decimals=decimals),
        )
    # RuntimeError: "round_cpu" not implemented for 'int'; integers pass through
    return a


around = round
round_ = round
def real_if_close(a: ArrayLike, tol=100):
    """numpy.real_if_close: drop a negligible imaginary part."""
    if not torch.is_complex(a):
        return a
    # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
    # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
    if tol > 1:
        tol = tol * torch.finfo(a.dtype).eps

    if (torch.abs(a.imag) < tol).all():
        return a.real
    return a
def real(a: ArrayLike):
    """numpy.real: the real part of the array."""
    return torch.real(a)
def imag(a: ArrayLike):
    """numpy.imag: the imaginary part, all zeros for a real input."""
    if not a.is_complex():
        # real tensors have an identically zero imaginary part
        return torch.zeros_like(a)
    return a.imag
def iscomplex(x: ArrayLike):
    """Elementwise test for a nonzero imaginary part (numpy.iscomplex)."""
    if not torch.is_complex(x):
        # real input: elementwise False
        return torch.zeros_like(x, dtype=torch.bool)
    return x.imag != 0
def isreal(x: ArrayLike):
    """Elementwise test for a zero imaginary part (numpy.isreal)."""
    if not torch.is_complex(x):
        # real input: elementwise True
        return torch.ones_like(x, dtype=torch.bool)
    return x.imag == 0
def iscomplexobj(x: ArrayLike):
    """True when the array's dtype is complex (numpy.iscomplexobj)."""
    return x.is_complex()
def isrealobj(x: ArrayLike):
    """True when the array's dtype is not complex (numpy.isrealobj)."""
    return not x.is_complex()
def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
    """Elementwise test for negative infinity (numpy.isneginf)."""
    return x.isneginf()
def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
    """Elementwise test for positive infinity (numpy.isposinf)."""
    return x.isposinf()
def i0(x: ArrayLike):
    """Modified Bessel function of the first kind, order 0 (numpy.i0)."""
    # torch.i0 is an alias of torch.special.i0
    return torch.i0(x)
def isscalar(a):
    """numpy.isscalar analog: True when `a` normalizes to a single element."""
    # We need to use normalize_array_like, but we don't want to export it in funcs.py
    from ._normalizations import normalize_array_like

    try:
        return normalize_array_like(a).numel() == 1
    except Exception:
        # anything that cannot be normalized is not a scalar
        return False
- # ### Filter windows ###
def hamming(M):
    """numpy.hamming: symmetric Hamming window of length M."""
    return torch.hamming_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def hanning(M):
    """numpy.hanning: symmetric Hann window of length M."""
    return torch.hann_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def kaiser(M, beta):
    """numpy.kaiser: symmetric Kaiser window of length M with shape `beta`."""
    return torch.kaiser_window(
        M, beta=beta, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def blackman(M):
    """numpy.blackman: symmetric Blackman window of length M."""
    return torch.blackman_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def bartlett(M):
    """numpy.bartlett: symmetric Bartlett (triangular) window of length M."""
    return torch.bartlett_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
# ### Dtype routines ###

# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
# array_type[is_complex][precision] gives the promoted dtype; array_precision
# maps a dtype to its precision index. Both are consumed by common_type below.
array_type = [
    [torch.float16, torch.float32, torch.float64],
    [None, torch.complex64, torch.complex128],
]
array_precision = {
    torch.float16: 0,
    torch.float32: 1,
    torch.float64: 2,
    torch.complex64: 1,
    torch.complex128: 2,
}
def common_type(*tensors: ArrayLike):
    """numpy.common_type: smallest float/complex dtype holding all inputs."""
    is_complex = False
    precision = 0
    for t in tensors:
        if iscomplexobj(t):
            is_complex = True
        dt = t.dtype
        if not (dt.is_floating_point or dt.is_complex):
            p = 2  # array_precision[_nx.double] — integers promote to double
        else:
            p = array_precision.get(dt, None)
            if p is None:
                raise TypeError("can't get common type for non-numeric array")
        precision = builtins.max(precision, p)

    return array_type[1][precision] if is_complex else array_type[0][precision]
- # ### histograms ###
def histogram(
    a: ArrayLike,
    bins: ArrayLike = 10,
    range=None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """numpy.histogram via torch.histogram (complex weights unsupported)."""
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    if weights is not None and weights.dtype.is_complex:
        raise NotImplementedError("complex weights histogram.")

    # remember integer inputs so the results can be cast back to integers below
    is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
    is_w_int = weights is None or not weights.dtype.is_floating_point
    if is_a_int:
        # torch.histogram wants floating point input
        a = a.double()

    if weights is not None:
        weights = _util.cast_if_needed(weights, a.dtype)

    if isinstance(bins, torch.Tensor):
        if bins.ndim == 0:
            # bins was a single int
            bins = operator.index(bins)
        else:
            bins = _util.cast_if_needed(bins, a.dtype)

    if range is None:
        h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
    else:
        h, b = torch.histogram(
            a, bins, range=range, weight=weights, density=bool(density)
        )

    if not density and is_w_int:
        # counts of integer weights are integers
        h = h.long()
    if is_a_int:
        b = b.long()

    return h, b
def histogram2d(
    x,
    y,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """numpy.histogram2d: 2D histogram of the samples (x, y), via histogramdd."""
    # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821
    if len(x) != len(y):
        raise ValueError("x and y must have the same length.")

    try:
        N = len(bins)
    except TypeError:
        N = 1

    if N != 1 and N != 2:
        # a single array of bin edges: reuse it for both dimensions
        bins = [bins, bins]

    h, e = histogramdd((x, y), bins, range, normed, weights, density)

    # histogramdd returns edges per dimension; unpack for the 2D signature
    return h, e[0], e[1]
def histogramdd(
    sample,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """numpy.histogramdd via torch.histogramdd."""
    # have to normalize manually because `sample` interpretation differs
    # for a list of lists and a 2D array
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    from ._normalizations import normalize_array_like, normalize_seq_array_like

    if isinstance(sample, (list, tuple)):
        # a sequence of per-dimension coordinate arrays: transpose to (N, D)
        sample = normalize_array_like(sample).T
    else:
        sample = normalize_array_like(sample)

    sample = torch.atleast_2d(sample)

    if not (sample.dtype.is_floating_point or sample.dtype.is_complex):
        # torch.histogramdd wants floating point samples
        sample = sample.double()

    # bins is either an int, or a sequence of ints or a sequence of arrays
    bins_is_array = not (
        isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins)
    )
    if bins_is_array:
        bins = normalize_seq_array_like(bins)
        # keep the original dtypes so the returned edges can be cast back
        bins_dtypes = [b.dtype for b in bins]
        bins = [_util.cast_if_needed(b, sample.dtype) for b in bins]

    if range is not None:
        range = range.flatten().tolist()

    if weights is not None:
        # range=... is required : interleave min and max values per dimension
        mm = sample.aminmax(dim=0)
        range = torch.cat(mm).reshape(2, -1).T.flatten()
        range = tuple(range.tolist())
        weights = _util.cast_if_needed(weights, sample.dtype)
        w_kwd = {"weight": weights}
    else:
        w_kwd = {}

    h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd)

    if bins_is_array:
        b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)]

    return h, b
- # ### odds and ends
def min_scalar_type(a: ArrayLike, /):
    """numpy.min_scalar_type: smallest dtype that can hold the scalar's value."""
    # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288

    from ._dtypes import DType

    if a.numel() > 1:
        # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified."
        return DType(a.dtype)

    if a.dtype == torch.bool:
        dtype = torch.bool

    elif a.dtype.is_complex:
        # complex64 fits if both parts fit into float32's range
        fi = torch.finfo(torch.float32)
        fits_in_single = a.dtype == torch.complex64 or (
            fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
        )
        dtype = torch.complex64 if fits_in_single else torch.complex128

    elif a.dtype.is_floating_point:
        # walk float dtypes narrowest-first; float64 always fits, so the loop
        # is guaranteed to break
        for dt in [torch.float16, torch.float32, torch.float64]:
            fi = torch.finfo(dt)
            if fi.min <= a <= fi.max:
                dtype = dt
                break
    else:
        # must be integer
        for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
            # Prefer unsigned int where possible, as numpy does.
            ii = torch.iinfo(dt)
            if ii.min <= a <= ii.max:
                dtype = dt
                break

    return DType(dtype)
def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs):
    """numpy.pad, constant mode only; `constant_values` is honored."""
    if mode != "constant":
        raise NotImplementedError

    fill = kwargs.get("constant_values", 0)
    # `fill` must be a python scalar for torch.nn.functional.pad
    fill = _dtypes_impl.python_type_for_torch(array.dtype)(fill)

    widths = torch.broadcast_to(pad_width, (array.ndim, 2))
    # F.pad lists pads starting from the last dimension, hence the flip
    widths = torch.flip(widths, (0,)).flatten()

    return torch.nn.functional.pad(array, tuple(widths), value=fill)
|