# mypy: ignore-errors

from __future__ import annotations

import builtins
import math
import operator
from typing import Sequence

import torch

from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
from ._normalizations import (
    ArrayLike,
    normalize_array_like,
    normalizer,
    NotImplementedType,
)

newaxis = None

FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    def __init__(self, flag_to_value: dict):
        assert all(k in FLAGS for k in flag_to_value.keys())  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        if attr.islower() and attr.upper() in FLAGS:
            return self[attr.upper()]
        else:
            raise AttributeError(f"No flag attribute '{attr}'")

    def __getitem__(self, key):
        if key in SHORTHAND_TO_FLAGS:
            key = SHORTHAND_TO_FLAGS[key]
        if key in FLAGS:
            try:
                return self._flag_to_value[key]
            except KeyError as e:
                raise NotImplementedError(f"{key=}") from e
        else:
            raise KeyError(f"No flag key '{key}'")

    def __setattr__(self, attr, value):
        if attr.islower() and attr.upper() in FLAGS:
            self[attr.upper()] = value
        else:
            super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        if key in FLAGS or key in SHORTHAND_TO_FLAGS:
            raise NotImplementedError("Modifying flags is not implemented")
        else:
            raise KeyError(f"No flag key '{key}'")
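

# Illustrative usage of Flags (a sketch, not part of the module API): values
# can be read via full uppercase keys, single-letter shorthands, or lowercase
# attributes; flags absent from the mapping raise NotImplementedError, and
# all writes are rejected.
def _demo_flags():
    flags = Flags({"C_CONTIGUOUS": True, "WRITEABLE": True})
    assert flags["C_CONTIGUOUS"] is True  # full key
    assert flags["C"] is True  # shorthand key, via SHORTHAND_TO_FLAGS
    assert flags.writeable is True  # lowercase attribute access
    try:
        flags["ALIGNED"]  # valid flag, but not recorded in the mapping
    except NotImplementedError:
        pass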


def create_method(fn, name=None):
    # Wrap `fn` in a fresh function so it can be attached to ndarray as a
    # method with its own __name__/__qualname__.
    name = name or fn.__name__

    def f(*args, **kwargs):
        return fn(*args, **kwargs)

    f.__name__ = name
    f.__qualname__ = f"ndarray.{name}"
    return f


# Map ndarray.name_method -> np.name_func
# If name_func is None, the method and the NumPy function share a name.
methods = {
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}

dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# dunder methods with reflected (__r*__) and in-place (__i*__) variants
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
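

# A minimal, self-contained sketch (toy names, not part of this module) of
# the registration pattern the ndarray class body applies to the tables
# above: each entry becomes a method via create_method by writing into the
# class namespace (`vars()` inside a class body), and reflected variants are
# built by swapping the operands.
def _demo_method_registration():
    table = {"add": operator.add, "sub": operator.sub}  # toy stand-in for ri_dunder

    class Toy:
        def __init__(self, value):
            self.value = value

        for _method, _fn in table.items():
            dname = f"__{_method}__"
            vars()[dname] = create_method(
                lambda self, other, fn=_fn: fn(self.value, other), dname
            )
            rname = f"__r{_method}__"
            vars()[rname] = create_method(
                lambda self, other, fn=_fn: fn(other, self.value), rname
            )
        # prevent loop variables leaking into the class namespace
        del _method, _fn, dname, rname

    assert Toy(3) + 4 == 7  # __add__
    assert 10 - Toy(3) == 7  # __rsub__ (reflected: operands swapped)
    # Note create_method hardcodes the "ndarray." qualname prefix:
    assert Toy.__add__.__qualname__ == "ndarray.__add__"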


def _upcast_int_indices(index):
    if isinstance(index, torch.Tensor):
        if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
            return index.to(torch.int64)
    elif isinstance(index, tuple):
        return tuple(_upcast_int_indices(i) for i in index)
    return index
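

# A small sketch of what _upcast_int_indices does: small integer index
# tensors (which NumPy accepts, but torch advanced indexing generally wants
# as int64) are widened, recursing through tuple (multi-axis) indices;
# everything else passes through untouched.
def _demo_upcast_int_indices():
    idx = torch.tensor([0, 2], dtype=torch.int16)
    assert _upcast_int_indices(idx).dtype == torch.int64
    # Tuple indices are handled element-wise:
    row, col = _upcast_int_indices((idx, torch.tensor([1], dtype=torch.uint8)))
    assert row.dtype == col.dtype == torch.int64
    # Non-tensor indices such as slices are returned as-is:
    assert _upcast_int_indices(slice(None)) == slice(None)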


# Used to indicate that a parameter is unspecified (as opposed to explicitly
# `None`)
class _Unspecified:
    pass


_Unspecified.unspecified = _Unspecified()
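

# A small sketch of the sentinel pattern above: it lets ndarray.view (defined
# below) distinguish "argument omitted" from "argument explicitly None".
def _demo_unspecified_sentinel(dtype=_Unspecified.unspecified):
    if dtype is _Unspecified.unspecified:
        return "omitted"
    return f"explicitly passed {dtype!r}"


# _demo_unspecified_sentinel()      -> "omitted"
# _demo_unspecified_sentinel(None)  -> "explicitly passed None"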


###############################################################
#                        ndarray class                        #
###############################################################
class ndarray:
    def __init__(self, t=None):
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        return tuple(self.tensor.shape)

    @property
    def size(self):
        return self.tensor.numel()

    @property
    def ndim(self):
        return self.tensor.ndim

    @property
    def dtype(self):
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        return self.tensor.element_size()

    @property
    def flags(self):
        # Note: in torch, "contiguous" means C-style contiguous
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        return self.tensor.storage().nbytes()

    @property
    def T(self):
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order}) is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok}) is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        # NB: differs from np.resize: fills with zeros instead of making
        # repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return
        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)
        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")
        # capture the element counts before resizing in place
        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()
        self.tensor.resize_(new_shape)
        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both PyTorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        tensor = self.tensor

        def neg_step(i, s):
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (i,))

            # Account for the fact that a slice includes the start but not the end
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None

            start = s.stop + 1 if s.stop else None
            stop = s.start + 1 if s.start else None

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
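

# An illustrative sketch of a few ndarray behaviors implemented above (not
# part of the module API); it relies only on functions defined in this file.
def _demo_ndarray():
    a = asarray([1, 2, 3, 4])

    # Negative-step slicing is rewritten in __getitem__ as a torch.flip plus
    # a positive-step slice, since torch has no negative-stride views.
    assert a[::-1].tolist() == [4, 3, 2, 1]

    # resize() zero-fills newly added trailing elements.
    b = asarray([1, 2])
    b.resize(4)
    assert b.tolist() == [1, 2, 0, 0]

    # Per the fallback in __eq__ above, comparing against something that
    # cannot be coerced to an array yields all-False instead of raising.
    assert (a == "not an array").tolist() == [False, False, False, False]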


def _tolist(obj):
    """Recursively convert tensors into lists."""
    a1 = []
    for elem in obj:
        if isinstance(elem, (list, tuple)):
            elem = _tolist(elem)
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    return a1
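

# A small sketch of _tolist on a mixed nested input: ndarrays become plain
# Python values, nested sequences are recursed into, and other elements pass
# through unchanged.
def _demo_tolist():
    mixed = [1, [2, 3], asarray(4)]
    assert _tolist(mixed) == [1, [2, 3], 4]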


# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)
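

# An illustrative sketch of array()'s coercion rules (not part of the API,
# and assuming _dtypes.dtype accepts NumPy-style names such as "float64"):
# the "happy path" returns the very same ndarray object when no copy, dtype
# change, or dimension padding is requested.
def _demo_array():
    a = asarray([1, 2, 3])
    # asarray is array(..., copy=False): an existing ndarray passes through.
    assert asarray(a) is a
    # Requesting a dtype forces conversion into a new ndarray.
    b = array(a, dtype="float64")
    assert b is not a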


def asarray(a, dtype=None, order="K", *, like=None):
    return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def ascontiguousarray(a, dtype=None, *, like=None):
    arr = asarray(a, dtype=dtype, like=like)
    if not arr.tensor.is_contiguous():
        arr.tensor = arr.tensor.contiguous()
    return arr


def from_dlpack(x, /):
    t = torch.from_dlpack(x)
    return ndarray(t)


def _extract_dtype(entry):
    try:
        dty = _dtypes.dtype(entry)
    except Exception:
        dty = asarray(entry).dtype
    return dty


def can_cast(from_, to, casting="safe"):
    from_ = _extract_dtype(from_)
    to_ = _extract_dtype(to)

    return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)


def result_type(*arrays_and_dtypes):
    tensors = []
    for entry in arrays_and_dtypes:
        try:
            t = asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            dty = _dtypes.dtype(entry)
            t = torch.empty(1, dtype=dty.torch_dtype)
        tensors.append(t)

    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
    return _dtypes.dtype(torch_dtype)
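

# An illustrative sketch of the promotion helpers above (a sketch, assuming
# _dtypes.dtype accepts NumPy-style names such as "int8"/"int16"):
# result_type mixes arrays and dtype specifiers freely because each entry is
# reduced to a representative tensor first.
def _demo_type_promotion():
    assert can_cast("int8", "int16")  # widening integer casts are "safe"
    dt = result_type("int8", asarray([1, 2], dtype="int16"))
    assert dt.torch_dtype == torch.int16  # int8 + int16 promotes to int16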