# mypy: ignore-errors

from __future__ import annotations

import builtins
import math
import operator
from typing import Sequence

import torch

from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
    ArrayLike,
    normalize_array_like,
    normalizer,
    NotImplementedType,
)

newaxis = None

FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    def __init__(self, flag_to_value: dict):
        assert all(k in FLAGS for k in flag_to_value.keys())  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        if attr.islower() and attr.upper() in FLAGS:
            return self[attr.upper()]
        else:
            raise AttributeError(f"No flag attribute '{attr}'")

    def __getitem__(self, key):
        if key in SHORTHAND_TO_FLAGS.keys():
            key = SHORTHAND_TO_FLAGS[key]
        if key in FLAGS:
            try:
                return self._flag_to_value[key]
            except KeyError as e:
                raise NotImplementedError(f"{key=}") from e
        else:
            raise KeyError(f"No flag key '{key}'")

    def __setattr__(self, attr, value):
        if attr.islower() and attr.upper() in FLAGS:
            self[attr.upper()] = value
        else:
            super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
            raise NotImplementedError("Modifying flags is not implemented")
        else:
            raise KeyError(f"No flag key '{key}'")
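

# Illustrative flag lookups (mirrors the behavior of np.ndarray.flags):
#
#     f = Flags({"C_CONTIGUOUS": True, "WRITEABLE": True})
#     f.c_contiguous   # True  (lowercase attribute resolves to the upper-case flag)
#     f["C"]           # True  (shorthand resolved via SHORTHAND_TO_FLAGS)
#     f["OWNDATA"]     # NotImplementedError: known flag whose value was not populated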


def create_method(fn, name=None):
    name = name or fn.__name__

    def f(*args, **kwargs):
        return fn(*args, **kwargs)

    f.__name__ = name
    f.__qualname__ = f"ndarray.{name}"

    return f
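

# create_method wraps a free function from _funcs/_ufuncs so it can be attached
# to the ndarray class below; copying __name__/__qualname__ keeps tracebacks and
# introspection readable. For example:
#
#     clip = create_method(_funcs.clip)                 # exposed as ndarray.clip
#     conj = create_method(_ufuncs.conjugate, "conj")   # exposed under a different name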


# Map ndarray.name_method -> np.name_func
# A value of None means the method and the function share the same name.
methods = {
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}

dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# dunder methods with reflected (__r*__) and in-place (__i*__) variants
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
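
# For each entry above, the ndarray class body below generates three dunders;
# for "add" this yields __add__, __radd__ (arguments swapped) and __iadd__
# (out=self), all delegating to _ufuncs.add.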


def _upcast_int_indices(index):
    if isinstance(index, torch.Tensor):
        if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
            return index.to(torch.int64)
    elif isinstance(index, tuple):
        return tuple(_upcast_int_indices(i) for i in index)
    return index
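

# NumPy accepts any integer dtype as an index array, while torch's advanced
# indexing expects int64; _upcast_int_indices bridges the two by widening small
# integer index tensors (recursing into tuple indices) before they reach
# Tensor.__getitem__ / __setitem__.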


# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
    pass


_Unspecified.unspecified = _Unspecified()


###############################################################
#                        ndarray class                        #
###############################################################


class ndarray:
    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)

        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)

        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)

        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note: Tensor.is_contiguous() checks C-style (row-major) contiguity.
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        return self.tensor.storage().nbytes()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        if order != "K":
            raise NotImplementedError(
                f"astype(..., order={order}) is not implemented."
            )
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(
                f"astype(..., subok={subok}) is not implemented."
            )
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        # NB: differs from np.resize: fills with zeros instead of making
        # repeated copies of the input, e.g. (illustrative):
        #
        #     a = asarray([1, 2, 3])
        #     a.resize(5)     # a is now [1, 2, 3, 0, 0]
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both PyTorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on arrays with ndim > 0.
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axes=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(i, s):
            # torch does not support negative slice steps; emulate them by
            # flipping the indexed dimension and rewriting the slice bounds.
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (i,))

            # Account for the fact that a slice includes the start but not the end
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None

            start = s.stop + 1 if s.stop else None
            stop = s.start + 1 if s.start else None

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recursively convert tensors into lists."""
    a1 = []
    for elem in obj:
        if isinstance(elem, (list, tuple)):
            elem = _tolist(elem)
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    return a1


# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)
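

# A few representative inputs accepted by array() (illustrative):
#
#     array(asarray([1, 2]), copy=False)           # happy path: returned unchanged
#     array([torch.tensor(1), torch.tensor(2)])    # list of tensors -> torch.stack
#     array([[1, 2], asarray([3, 4])])             # mixed nesting -> _tolist, then coerce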


def asarray(a, dtype=None, order="K", *, like=None):
    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def ascontiguousarray(a, dtype=None, *, like=None):
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        arr.tensor = arr.tensor.contiguous()
    return arr


def from_dlpack(x, /):
    t = torch.from_dlpack(x)
    return ndarray(t)


def _extract_dtype(entry):
    try:
        dty = _dtypes.dtype(entry)
    except Exception:
        dty = asarray(entry).dtype
    return dty


def can_cast(from_, to, casting="safe"):
    from_ = _extract_dtype(from_)
    to_ = _extract_dtype(to)

    return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
    tensors = []
    for entry in arrays_and_dtypes:
        try:
            t = asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            dty = _dtypes.dtype(entry)
            t = torch.empty(1, dtype=dty.torch_dtype)
        tensors.append(t)

    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
    return _dtypes.dtype(torch_dtype)
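

# Both helpers accept dtype specifiers as well as array-likes; illustrative calls
# (assuming the _dtypes.dtype constructor accepts NumPy-style names like "int32"):
#
#     can_cast("int32", "float64")             # compare two dtype-likes
#     result_type(asarray([1, 2]), "float64")  # mix an array-like with a dtype name
#
# result_type materializes a one-element tensor for dtype-only arguments so that
# _dtypes_impl.result_type_impl can operate purely on torch tensors.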