_funcs_impl.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055
  1. # mypy: ignore-errors
  2. """A thin pytorch / numpy compat layer.
  3. Things imported from here have numpy-compatible signatures but operate on
  4. pytorch tensors.
  5. """
  6. # The contents of this module end up in the main namespace via _funcs.py,
  7. # where type annotations are used in conjunction with the @normalizer decorator.
  8. from __future__ import annotations
  9. import builtins
  10. import itertools
  11. import operator
  12. from typing import Optional, Sequence, TYPE_CHECKING
  13. import torch
  14. from . import _dtypes_impl, _util
  15. if TYPE_CHECKING:
  16. from ._normalizations import (
  17. ArrayLike,
  18. ArrayLikeOrScalar,
  19. CastingModes,
  20. DTypeLike,
  21. NDArray,
  22. NotImplementedType,
  23. OutArray,
  24. )
  25. def copy(
  26. a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
  27. ):
  28. return a.clone()
  29. def copyto(
  30. dst: NDArray,
  31. src: ArrayLike,
  32. casting: Optional[CastingModes] = "same_kind",
  33. where: NotImplementedType = None,
  34. ):
  35. (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
  36. dst.copy_(src)
  37. def atleast_1d(*arys: ArrayLike):
  38. res = torch.atleast_1d(*arys)
  39. if isinstance(res, tuple):
  40. return list(res)
  41. else:
  42. return res
  43. def atleast_2d(*arys: ArrayLike):
  44. res = torch.atleast_2d(*arys)
  45. if isinstance(res, tuple):
  46. return list(res)
  47. else:
  48. return res
  49. def atleast_3d(*arys: ArrayLike):
  50. res = torch.atleast_3d(*arys)
  51. if isinstance(res, tuple):
  52. return list(res)
  53. else:
  54. return res
  55. def _concat_check(tup, dtype, out):
  56. if tup == ():
  57. raise ValueError("need at least one array to concatenate")
  58. """Check inputs in concatenate et al."""
  59. if out is not None and dtype is not None:
  60. # mimic numpy
  61. raise TypeError(
  62. "concatenate() only takes `out` or `dtype` as an "
  63. "argument, but both were provided."
  64. )
  65. def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
  66. """Figure out dtypes, cast if necessary."""
  67. if out is not None or dtype is not None:
  68. # figure out the type of the inputs and outputs
  69. out_dtype = out.dtype.torch_dtype if dtype is None else dtype
  70. else:
  71. out_dtype = _dtypes_impl.result_type_impl(*tensors)
  72. # cast input arrays if necessary; do not broadcast them agains `out`
  73. tensors = _util.typecast_tensors(tensors, out_dtype, casting)
  74. return tensors
  75. def _concatenate(
  76. tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
  77. ):
  78. # pure torch implementation, used below and in cov/corrcoef below
  79. tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
  80. tensors = _concat_cast_helper(tensors, out, dtype, casting)
  81. return torch.cat(tensors, axis)
  82. def concatenate(
  83. ar_tuple: Sequence[ArrayLike],
  84. axis=0,
  85. out: Optional[OutArray] = None,
  86. dtype: Optional[DTypeLike] = None,
  87. casting: Optional[CastingModes] = "same_kind",
  88. ):
  89. _concat_check(ar_tuple, dtype, out=out)
  90. result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
  91. return result
  92. def vstack(
  93. tup: Sequence[ArrayLike],
  94. *,
  95. dtype: Optional[DTypeLike] = None,
  96. casting: Optional[CastingModes] = "same_kind",
  97. ):
  98. _concat_check(tup, dtype, out=None)
  99. tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
  100. return torch.vstack(tensors)
  101. row_stack = vstack
  102. def hstack(
  103. tup: Sequence[ArrayLike],
  104. *,
  105. dtype: Optional[DTypeLike] = None,
  106. casting: Optional[CastingModes] = "same_kind",
  107. ):
  108. _concat_check(tup, dtype, out=None)
  109. tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
  110. return torch.hstack(tensors)
  111. def dstack(
  112. tup: Sequence[ArrayLike],
  113. *,
  114. dtype: Optional[DTypeLike] = None,
  115. casting: Optional[CastingModes] = "same_kind",
  116. ):
  117. # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
  118. # but {h,v}stack do. Hence add them here for consistency.
  119. _concat_check(tup, dtype, out=None)
  120. tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
  121. return torch.dstack(tensors)
  122. def column_stack(
  123. tup: Sequence[ArrayLike],
  124. *,
  125. dtype: Optional[DTypeLike] = None,
  126. casting: Optional[CastingModes] = "same_kind",
  127. ):
  128. # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
  129. # but row_stack does. (because row_stack is an alias for vstack, really).
  130. # Hence add these keywords here for consistency.
  131. _concat_check(tup, dtype, out=None)
  132. tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
  133. return torch.column_stack(tensors)
  134. def stack(
  135. arrays: Sequence[ArrayLike],
  136. axis=0,
  137. out: Optional[OutArray] = None,
  138. *,
  139. dtype: Optional[DTypeLike] = None,
  140. casting: Optional[CastingModes] = "same_kind",
  141. ):
  142. _concat_check(arrays, dtype, out=out)
  143. tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
  144. result_ndim = tensors[0].ndim + 1
  145. axis = _util.normalize_axis_index(axis, result_ndim)
  146. return torch.stack(tensors, axis=axis)
  147. def append(arr: ArrayLike, values: ArrayLike, axis=None):
  148. if axis is None:
  149. if arr.ndim != 1:
  150. arr = arr.flatten()
  151. values = values.flatten()
  152. axis = arr.ndim - 1
  153. return _concatenate((arr, values), axis=axis)
  154. # ### split ###
  155. def _split_helper(tensor, indices_or_sections, axis, strict=False):
  156. if isinstance(indices_or_sections, int):
  157. return _split_helper_int(tensor, indices_or_sections, axis, strict)
  158. elif isinstance(indices_or_sections, (list, tuple)):
  159. # NB: drop split=..., it only applies to split_helper_int
  160. return _split_helper_list(tensor, list(indices_or_sections), axis)
  161. else:
  162. raise TypeError("split_helper: ", type(indices_or_sections))
  163. def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
  164. if not isinstance(indices_or_sections, int):
  165. raise NotImplementedError("split: indices_or_sections")
  166. axis = _util.normalize_axis_index(axis, tensor.ndim)
  167. # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
  168. l, n = tensor.shape[axis], indices_or_sections
  169. if n <= 0:
  170. raise ValueError
  171. if l % n == 0:
  172. num, sz = n, l // n
  173. lst = [sz] * num
  174. else:
  175. if strict:
  176. raise ValueError("array split does not result in an equal division")
  177. num, sz = l % n, l // n + 1
  178. lst = [sz] * num
  179. lst += [sz - 1] * (n - num)
  180. return torch.split(tensor, lst, axis)
  181. def _split_helper_list(tensor, indices_or_sections, axis):
  182. if not isinstance(indices_or_sections, list):
  183. raise NotImplementedError("split: indices_or_sections: list")
  184. # numpy expects indices, while torch expects lengths of sections
  185. # also, numpy appends zero-size arrays for indices above the shape[axis]
  186. lst = [x for x in indices_or_sections if x <= tensor.shape[axis]]
  187. num_extra = len(indices_or_sections) - len(lst)
  188. lst.append(tensor.shape[axis])
  189. lst = [
  190. lst[0],
  191. ] + [a - b for a, b in zip(lst[1:], lst[:-1])]
  192. lst += [0] * num_extra
  193. return torch.split(tensor, lst, axis)
  194. def array_split(ary: ArrayLike, indices_or_sections, axis=0):
  195. return _split_helper(ary, indices_or_sections, axis)
  196. def split(ary: ArrayLike, indices_or_sections, axis=0):
  197. return _split_helper(ary, indices_or_sections, axis, strict=True)
  198. def hsplit(ary: ArrayLike, indices_or_sections):
  199. if ary.ndim == 0:
  200. raise ValueError("hsplit only works on arrays of 1 or more dimensions")
  201. axis = 1 if ary.ndim > 1 else 0
  202. return _split_helper(ary, indices_or_sections, axis, strict=True)
  203. def vsplit(ary: ArrayLike, indices_or_sections):
  204. if ary.ndim < 2:
  205. raise ValueError("vsplit only works on arrays of 2 or more dimensions")
  206. return _split_helper(ary, indices_or_sections, 0, strict=True)
  207. def dsplit(ary: ArrayLike, indices_or_sections):
  208. if ary.ndim < 3:
  209. raise ValueError("dsplit only works on arrays of 3 or more dimensions")
  210. return _split_helper(ary, indices_or_sections, 2, strict=True)
  211. def kron(a: ArrayLike, b: ArrayLike):
  212. return torch.kron(a, b)
  213. def vander(x: ArrayLike, N=None, increasing=False):
  214. return torch.vander(x, N, increasing)
  215. # ### linspace, geomspace, logspace and arange ###
  216. def linspace(
  217. start: ArrayLike,
  218. stop: ArrayLike,
  219. num=50,
  220. endpoint=True,
  221. retstep=False,
  222. dtype: Optional[DTypeLike] = None,
  223. axis=0,
  224. ):
  225. if axis != 0 or retstep or not endpoint:
  226. raise NotImplementedError
  227. if dtype is None:
  228. dtype = _dtypes_impl.default_dtypes().float_dtype
  229. # XXX: raises TypeError if start or stop are not scalars
  230. return torch.linspace(start, stop, num, dtype=dtype)
  231. def geomspace(
  232. start: ArrayLike,
  233. stop: ArrayLike,
  234. num=50,
  235. endpoint=True,
  236. dtype: Optional[DTypeLike] = None,
  237. axis=0,
  238. ):
  239. if axis != 0 or not endpoint:
  240. raise NotImplementedError
  241. base = torch.pow(stop / start, 1.0 / (num - 1))
  242. logbase = torch.log(base)
  243. return torch.logspace(
  244. torch.log(start) / logbase,
  245. torch.log(stop) / logbase,
  246. num,
  247. base=base,
  248. )
  249. def logspace(
  250. start,
  251. stop,
  252. num=50,
  253. endpoint=True,
  254. base=10.0,
  255. dtype: Optional[DTypeLike] = None,
  256. axis=0,
  257. ):
  258. if axis != 0 or not endpoint:
  259. raise NotImplementedError
  260. return torch.logspace(start, stop, num, base=base, dtype=dtype)
  261. def arange(
  262. start: Optional[ArrayLikeOrScalar] = None,
  263. stop: Optional[ArrayLikeOrScalar] = None,
  264. step: Optional[ArrayLikeOrScalar] = 1,
  265. dtype: Optional[DTypeLike] = None,
  266. *,
  267. like: NotImplementedType = None,
  268. ):
  269. if step == 0:
  270. raise ZeroDivisionError
  271. if stop is None and start is None:
  272. raise TypeError
  273. if stop is None:
  274. # XXX: this breaks if start is passed as a kwarg:
  275. # arange(start=4) should raise (no stop) but doesn't
  276. start, stop = 0, start
  277. if start is None:
  278. start = 0
  279. # the dtype of the result
  280. if dtype is None:
  281. dtype = (
  282. _dtypes_impl.default_dtypes().float_dtype
  283. if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step))
  284. else _dtypes_impl.default_dtypes().int_dtype
  285. )
  286. work_dtype = torch.float64 if dtype.is_complex else dtype
  287. # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager.
  288. if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)):
  289. raise NotImplementedError
  290. if (step > 0 and start > stop) or (step < 0 and start < stop):
  291. # empty range
  292. return torch.empty(0, dtype=dtype)
  293. result = torch.arange(start, stop, step, dtype=work_dtype)
  294. result = _util.cast_if_needed(result, dtype)
  295. return result
  296. # ### zeros/ones/empty/full ###
  297. def empty(
  298. shape,
  299. dtype: Optional[DTypeLike] = None,
  300. order: NotImplementedType = "C",
  301. *,
  302. like: NotImplementedType = None,
  303. ):
  304. if dtype is None:
  305. dtype = _dtypes_impl.default_dtypes().float_dtype
  306. return torch.empty(shape, dtype=dtype)
  307. # NB: *_like functions deliberately deviate from numpy: it has subok=True
  308. # as the default; we set subok=False and raise on anything else.
  309. def empty_like(
  310. prototype: ArrayLike,
  311. dtype: Optional[DTypeLike] = None,
  312. order: NotImplementedType = "K",
  313. subok: NotImplementedType = False,
  314. shape=None,
  315. ):
  316. result = torch.empty_like(prototype, dtype=dtype)
  317. if shape is not None:
  318. result = result.reshape(shape)
  319. return result
  320. def full(
  321. shape,
  322. fill_value: ArrayLike,
  323. dtype: Optional[DTypeLike] = None,
  324. order: NotImplementedType = "C",
  325. *,
  326. like: NotImplementedType = None,
  327. ):
  328. if isinstance(shape, int):
  329. shape = (shape,)
  330. if dtype is None:
  331. dtype = fill_value.dtype
  332. if not isinstance(shape, (tuple, list)):
  333. shape = (shape,)
  334. return torch.full(shape, fill_value, dtype=dtype)
  335. def full_like(
  336. a: ArrayLike,
  337. fill_value,
  338. dtype: Optional[DTypeLike] = None,
  339. order: NotImplementedType = "K",
  340. subok: NotImplementedType = False,
  341. shape=None,
  342. ):
  343. # XXX: fill_value broadcasts
  344. result = torch.full_like(a, fill_value, dtype=dtype)
  345. if shape is not None:
  346. result = result.reshape(shape)
  347. return result
  348. def ones(
  349. shape,
  350. dtype: Optional[DTypeLike] = None,
  351. order: NotImplementedType = "C",
  352. *,
  353. like: NotImplementedType = None,
  354. ):
  355. if dtype is None:
  356. dtype = _dtypes_impl.default_dtypes().float_dtype
  357. return torch.ones(shape, dtype=dtype)
  358. def ones_like(
  359. a: ArrayLike,
  360. dtype: Optional[DTypeLike] = None,
  361. order: NotImplementedType = "K",
  362. subok: NotImplementedType = False,
  363. shape=None,
  364. ):
  365. result = torch.ones_like(a, dtype=dtype)
  366. if shape is not None:
  367. result = result.reshape(shape)
  368. return result
  369. def zeros(
  370. shape,
  371. dtype: Optional[DTypeLike] = None,
  372. order: NotImplementedType = "C",
  373. *,
  374. like: NotImplementedType = None,
  375. ):
  376. if dtype is None:
  377. dtype = _dtypes_impl.default_dtypes().float_dtype
  378. return torch.zeros(shape, dtype=dtype)
  379. def zeros_like(
  380. a: ArrayLike,
  381. dtype: Optional[DTypeLike] = None,
  382. order: NotImplementedType = "K",
  383. subok: NotImplementedType = False,
  384. shape=None,
  385. ):
  386. result = torch.zeros_like(a, dtype=dtype)
  387. if shape is not None:
  388. result = result.reshape(shape)
  389. return result
  390. # ### cov & corrcoef ###
  391. def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
  392. """Prepare inputs for cov and corrcoef."""
  393. # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
  394. if y_tensor is not None:
  395. # make sure x and y are at least 2D
  396. ndim_extra = 2 - x_tensor.ndim
  397. if ndim_extra > 0:
  398. x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape)
  399. if not rowvar and x_tensor.shape[0] != 1:
  400. x_tensor = x_tensor.mT
  401. x_tensor = x_tensor.clone()
  402. ndim_extra = 2 - y_tensor.ndim
  403. if ndim_extra > 0:
  404. y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape)
  405. if not rowvar and y_tensor.shape[0] != 1:
  406. y_tensor = y_tensor.mT
  407. y_tensor = y_tensor.clone()
  408. x_tensor = _concatenate((x_tensor, y_tensor), axis=0)
  409. return x_tensor
  410. def corrcoef(
  411. x: ArrayLike,
  412. y: Optional[ArrayLike] = None,
  413. rowvar=True,
  414. bias=None,
  415. ddof=None,
  416. *,
  417. dtype: Optional[DTypeLike] = None,
  418. ):
  419. if bias is not None or ddof is not None:
  420. # deprecated in NumPy
  421. raise NotImplementedError
  422. xy_tensor = _xy_helper_corrcoef(x, y, rowvar)
  423. is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu
  424. if is_half:
  425. # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
  426. dtype = torch.float32
  427. xy_tensor = _util.cast_if_needed(xy_tensor, dtype)
  428. result = torch.corrcoef(xy_tensor)
  429. if is_half:
  430. result = result.to(torch.float16)
  431. return result
  432. def cov(
  433. m: ArrayLike,
  434. y: Optional[ArrayLike] = None,
  435. rowvar=True,
  436. bias=False,
  437. ddof=None,
  438. fweights: Optional[ArrayLike] = None,
  439. aweights: Optional[ArrayLike] = None,
  440. *,
  441. dtype: Optional[DTypeLike] = None,
  442. ):
  443. m = _xy_helper_corrcoef(m, y, rowvar)
  444. if ddof is None:
  445. ddof = 1 if bias == 0 else 0
  446. is_half = (m.dtype == torch.float16) and m.is_cpu
  447. if is_half:
  448. # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
  449. dtype = torch.float32
  450. m = _util.cast_if_needed(m, dtype)
  451. result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights)
  452. if is_half:
  453. result = result.to(torch.float16)
  454. return result
  455. def _conv_corr_impl(a, v, mode):
  456. dt = _dtypes_impl.result_type_impl(a, v)
  457. a = _util.cast_if_needed(a, dt)
  458. v = _util.cast_if_needed(v, dt)
  459. padding = v.shape[0] - 1 if mode == "full" else mode
  460. if padding == "same" and v.shape[0] % 2 == 0:
  461. # UserWarning: Using padding='same' with even kernel lengths and odd
  462. # dilation may require a zero-padded copy of the input be created
  463. # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.)
  464. raise NotImplementedError("mode='same' and even-length weights")
  465. # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
  466. aa = a[None, :]
  467. vv = v[None, None, :]
  468. result = torch.nn.functional.conv1d(aa, vv, padding=padding)
  469. # torch returns a 2D result, numpy returns a 1D array
  470. return result[0, :]
  471. def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
  472. # NumPy: if v is longer than a, the arrays are swapped before computation
  473. if a.shape[0] < v.shape[0]:
  474. a, v = v, a
  475. # flip the weights since numpy does and torch does not
  476. v = torch.flip(v, (0,))
  477. return _conv_corr_impl(a, v, mode)
  478. def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
  479. v = torch.conj_physical(v)
  480. return _conv_corr_impl(a, v, mode)
  481. # ### logic & element selection ###
  482. def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
  483. if x.numel() == 0:
  484. # edge case allowed by numpy
  485. x = x.new_empty(0, dtype=int)
  486. int_dtype = _dtypes_impl.default_dtypes().int_dtype
  487. (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe")
  488. return torch.bincount(x, weights, minlength)
  489. def where(
  490. condition: ArrayLike,
  491. x: Optional[ArrayLikeOrScalar] = None,
  492. y: Optional[ArrayLikeOrScalar] = None,
  493. /,
  494. ):
  495. if (x is None) != (y is None):
  496. raise ValueError("either both or neither of x and y should be given")
  497. if condition.dtype != torch.bool:
  498. condition = condition.to(torch.bool)
  499. if x is None and y is None:
  500. result = torch.where(condition)
  501. else:
  502. result = torch.where(condition, x, y)
  503. return result
  504. # ###### module-level queries of object properties
  505. def ndim(a: ArrayLike):
  506. return a.ndim
  507. def shape(a: ArrayLike):
  508. return tuple(a.shape)
  509. def size(a: ArrayLike, axis=None):
  510. if axis is None:
  511. return a.numel()
  512. else:
  513. return a.shape[axis]
  514. # ###### shape manipulations and indexing
  515. def expand_dims(a: ArrayLike, axis):
  516. shape = _util.expand_shape(a.shape, axis)
  517. return a.view(shape) # never copies
  518. def flip(m: ArrayLike, axis=None):
  519. # XXX: semantic difference: np.flip returns a view, torch.flip copies
  520. if axis is None:
  521. axis = tuple(range(m.ndim))
  522. else:
  523. axis = _util.normalize_axis_tuple(axis, m.ndim)
  524. return torch.flip(m, axis)
  525. def flipud(m: ArrayLike):
  526. return torch.flipud(m)
  527. def fliplr(m: ArrayLike):
  528. return torch.fliplr(m)
  529. def rot90(m: ArrayLike, k=1, axes=(0, 1)):
  530. axes = _util.normalize_axis_tuple(axes, m.ndim)
  531. return torch.rot90(m, k, axes)
  532. # ### broadcasting and indices ###
  533. def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
  534. return torch.broadcast_to(array, size=shape)
  535. # This is a function from tuples to tuples, so we just reuse it
  536. from torch import broadcast_shapes
  537. def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
  538. return torch.broadcast_tensors(*args)
  539. def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
  540. ndim = len(xi)
  541. if indexing not in ["xy", "ij"]:
  542. raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")
  543. s0 = (1,) * ndim
  544. output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)]
  545. if indexing == "xy" and ndim > 1:
  546. # switch first and second axis
  547. output[0] = output[0].reshape((1, -1) + s0[2:])
  548. output[1] = output[1].reshape((-1, 1) + s0[2:])
  549. if not sparse:
  550. # Return the full N-D matrix (not only the 1-D vector)
  551. output = torch.broadcast_tensors(*output)
  552. if copy:
  553. output = [x.clone() for x in output]
  554. return list(output) # match numpy, return a list
  555. def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False):
  556. # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
  557. dimensions = tuple(dimensions)
  558. N = len(dimensions)
  559. shape = (1,) * N
  560. if sparse:
  561. res = tuple()
  562. else:
  563. res = torch.empty((N,) + dimensions, dtype=dtype)
  564. for i, dim in enumerate(dimensions):
  565. idx = torch.arange(dim, dtype=dtype).reshape(
  566. shape[:i] + (dim,) + shape[i + 1 :]
  567. )
  568. if sparse:
  569. res = res + (idx,)
  570. else:
  571. res[i] = idx
  572. return res
  573. # ### tri*-something ###
  574. def tril(m: ArrayLike, k=0):
  575. return torch.tril(m, k)
  576. def triu(m: ArrayLike, k=0):
  577. return torch.triu(m, k)
  578. def tril_indices(n, k=0, m=None):
  579. if m is None:
  580. m = n
  581. return torch.tril_indices(n, m, offset=k)
  582. def triu_indices(n, k=0, m=None):
  583. if m is None:
  584. m = n
  585. return torch.triu_indices(n, m, offset=k)
  586. def tril_indices_from(arr: ArrayLike, k=0):
  587. if arr.ndim != 2:
  588. raise ValueError("input array must be 2-d")
  589. # Return a tensor rather than a tuple to avoid a graphbreak
  590. return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
  591. def triu_indices_from(arr: ArrayLike, k=0):
  592. if arr.ndim != 2:
  593. raise ValueError("input array must be 2-d")
  594. # Return a tensor rather than a tuple to avoid a graphbreak
  595. return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
  596. def tri(
  597. N,
  598. M=None,
  599. k=0,
  600. dtype: Optional[DTypeLike] = None,
  601. *,
  602. like: NotImplementedType = None,
  603. ):
  604. if M is None:
  605. M = N
  606. tensor = torch.ones((N, M), dtype=dtype)
  607. return torch.tril(tensor, diagonal=k)
  608. # ### equality, equivalence, allclose ###
  609. def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
  610. dtype = _dtypes_impl.result_type_impl(a, b)
  611. a = _util.cast_if_needed(a, dtype)
  612. b = _util.cast_if_needed(b, dtype)
  613. return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
  614. def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
  615. dtype = _dtypes_impl.result_type_impl(a, b)
  616. a = _util.cast_if_needed(a, dtype)
  617. b = _util.cast_if_needed(b, dtype)
  618. return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
  619. def _tensor_equal(a1, a2, equal_nan=False):
  620. # Implementation of array_equal/array_equiv.
  621. if a1.shape != a2.shape:
  622. return False
  623. cond = a1 == a2
  624. if equal_nan:
  625. cond = cond | (torch.isnan(a1) & torch.isnan(a2))
  626. return cond.all().item()
  627. def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False):
  628. return _tensor_equal(a1, a2, equal_nan=equal_nan)
  629. def array_equiv(a1: ArrayLike, a2: ArrayLike):
  630. # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not
  631. try:
  632. a1_t, a2_t = torch.broadcast_tensors(a1, a2)
  633. except RuntimeError:
  634. # failed to broadcast => not equivalent
  635. return False
  636. return _tensor_equal(a1_t, a2_t)
  637. def nan_to_num(
  638. x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None
  639. ):
  640. # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble'
  641. if x.is_complex():
  642. re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf)
  643. im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf)
  644. return re + 1j * im
  645. else:
  646. return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
  647. # ### put/take_along_axis ###
  648. def take(
  649. a: ArrayLike,
  650. indices: ArrayLike,
  651. axis=None,
  652. out: Optional[OutArray] = None,
  653. mode: NotImplementedType = "raise",
  654. ):
  655. (a,), axis = _util.axis_none_flatten(a, axis=axis)
  656. axis = _util.normalize_axis_index(axis, a.ndim)
  657. idx = (slice(None),) * axis + (indices, ...)
  658. result = a[idx]
  659. return result
  660. def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
  661. (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
  662. axis = _util.normalize_axis_index(axis, arr.ndim)
  663. return torch.take_along_dim(arr, indices, axis)
  664. def put(
  665. a: NDArray,
  666. indices: ArrayLike,
  667. values: ArrayLike,
  668. mode: NotImplementedType = "raise",
  669. ):
  670. v = values.type(a.dtype)
  671. # If indices is larger than v, expand v to at least the size of indices. Any
  672. # unnecessary trailing elements are then trimmed.
  673. if indices.numel() > v.numel():
  674. ratio = (indices.numel() + v.numel() - 1) // v.numel()
  675. v = v.unsqueeze(0).expand((ratio,) + v.shape)
  676. # Trim unnecessary elements, regardless if v was expanded or not. Note
  677. # np.put() trims v to match indices by default too.
  678. if indices.numel() < v.numel():
  679. v = v.flatten()
  680. v = v[: indices.numel()]
  681. a.put_(indices, v)
  682. return None
  683. def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
  684. (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
  685. axis = _util.normalize_axis_index(axis, arr.ndim)
  686. indices, values = torch.broadcast_tensors(indices, values)
  687. values = _util.cast_if_needed(values, arr.dtype)
  688. result = torch.scatter(arr, axis, indices, values)
  689. arr.copy_(result.reshape(arr.shape))
  690. return None
  691. def choose(
  692. a: ArrayLike,
  693. choices: Sequence[ArrayLike],
  694. out: Optional[OutArray] = None,
  695. mode: NotImplementedType = "raise",
  696. ):
  697. # First, broadcast elements of `choices`
  698. choices = torch.stack(torch.broadcast_tensors(*choices))
  699. # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`:
  700. # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939)
  701. idx_list = [
  702. torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1))
  703. for i, dim in enumerate(choices.shape)
  704. ]
  705. idx_list[0] = a
  706. return choices[idx_list].squeeze(0)
  707. # ### unique et al. ###
  708. def unique(
  709. ar: ArrayLike,
  710. return_index: NotImplementedType = False,
  711. return_inverse=False,
  712. return_counts=False,
  713. axis=None,
  714. *,
  715. equal_nan: NotImplementedType = True,
  716. ):
  717. (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
  718. axis = _util.normalize_axis_index(axis, ar.ndim)
  719. result = torch.unique(
  720. ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis
  721. )
  722. return result
  723. def nonzero(a: ArrayLike):
  724. return torch.nonzero(a, as_tuple=True)
  725. def argwhere(a: ArrayLike):
  726. return torch.argwhere(a)
  727. def flatnonzero(a: ArrayLike):
  728. return torch.flatten(a).nonzero(as_tuple=True)[0]
  729. def clip(
  730. a: ArrayLike,
  731. min: Optional[ArrayLike] = None,
  732. max: Optional[ArrayLike] = None,
  733. out: Optional[OutArray] = None,
  734. ):
  735. return torch.clamp(a, min, max)
  736. def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
  737. return torch.repeat_interleave(a, repeats, axis)
  738. def tile(A: ArrayLike, reps):
  739. if isinstance(reps, int):
  740. reps = (reps,)
  741. return torch.tile(A, reps)
  742. def resize(a: ArrayLike, new_shape=None):
  743. # implementation vendored from
  744. # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
  745. if new_shape is None:
  746. return a
  747. if isinstance(new_shape, int):
  748. new_shape = (new_shape,)
  749. a = a.flatten()
  750. new_size = 1
  751. for dim_length in new_shape:
  752. new_size *= dim_length
  753. if dim_length < 0:
  754. raise ValueError("all elements of `new_shape` must be non-negative")
  755. if a.numel() == 0 or new_size == 0:
  756. # First case must zero fill. The second would have repeats == 0.
  757. return torch.zeros(new_shape, dtype=a.dtype)
  758. repeats = -(-new_size // a.numel()) # ceil division
  759. a = concatenate((a,) * repeats)[:new_size]
  760. return reshape(a, new_shape)
  761. # ### diag et al. ###
  762. def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
  763. axis1 = _util.normalize_axis_index(axis1, a.ndim)
  764. axis2 = _util.normalize_axis_index(axis2, a.ndim)
  765. return torch.diagonal(a, offset, axis1, axis2)
  766. def trace(
  767. a: ArrayLike,
  768. offset=0,
  769. axis1=0,
  770. axis2=1,
  771. dtype: Optional[DTypeLike] = None,
  772. out: Optional[OutArray] = None,
  773. ):
  774. result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
  775. return result
  776. def eye(
  777. N,
  778. M=None,
  779. k=0,
  780. dtype: Optional[DTypeLike] = None,
  781. order: NotImplementedType = "C",
  782. *,
  783. like: NotImplementedType = None,
  784. ):
  785. if dtype is None:
  786. dtype = _dtypes_impl.default_dtypes().float_dtype
  787. if M is None:
  788. M = N
  789. z = torch.zeros(N, M, dtype=dtype)
  790. z.diagonal(k).fill_(1)
  791. return z
  792. def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None):
  793. return torch.eye(n, dtype=dtype)
  794. def diag(v: ArrayLike, k=0):
  795. return torch.diag(v, k)
  796. def diagflat(v: ArrayLike, k=0):
  797. return torch.diagflat(v, k)
  798. def diag_indices(n, ndim=2):
  799. idx = torch.arange(n)
  800. return (idx,) * ndim
  801. def diag_indices_from(arr: ArrayLike):
  802. if not arr.ndim >= 2:
  803. raise ValueError("input array must be at least 2-d")
  804. # For more than d=2, the strided formula is only valid for arrays with
  805. # all dimensions equal, so we check first.
  806. s = arr.shape
  807. if s[1:] != s[:-1]:
  808. raise ValueError("All dimensions of input must be of equal length")
  809. return diag_indices(s[0], arr.ndim)
  810. def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
  811. if a.ndim < 2:
  812. raise ValueError("array must be at least 2-d")
  813. if val.numel() == 0 and not wrap:
  814. a.fill_diagonal_(val)
  815. return a
  816. if val.ndim == 0:
  817. val = val.unsqueeze(0)
  818. # torch.Tensor.fill_diagonal_ only accepts scalars
  819. # If the size of val is too large, then val is trimmed
  820. if a.ndim == 2:
  821. tall = a.shape[0] > a.shape[1]
  822. # wrap does nothing for wide matrices...
  823. if not wrap or not tall:
  824. # Never wraps
  825. diag = a.diagonal()
  826. diag.copy_(val[: diag.numel()])
  827. else:
  828. # wraps and tall... leaving one empty line between diagonals?!
  829. max_, min_ = a.shape
  830. idx = torch.arange(max_ - max_ // (min_ + 1))
  831. mod = idx % min_
  832. div = idx // min_
  833. a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()]
  834. else:
  835. idx = diag_indices_from(a)
  836. # a.shape = (n, n, ..., n)
  837. a[idx] = val[: a.shape[0]]
  838. return a
  839. def vdot(a: ArrayLike, b: ArrayLike, /):
  840. # 1. torch only accepts 1D arrays, numpy flattens
  841. # 2. torch requires matching dtype, while numpy casts (?)
  842. t_a, t_b = torch.atleast_1d(a, b)
  843. if t_a.ndim > 1:
  844. t_a = t_a.flatten()
  845. if t_b.ndim > 1:
  846. t_b = t_b.flatten()
  847. dtype = _dtypes_impl.result_type_impl(t_a, t_b)
  848. is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu)
  849. is_bool = dtype == torch.bool
  850. # work around torch's "dot" not implemented for 'Half', 'Bool'
  851. if is_half:
  852. dtype = torch.float32
  853. elif is_bool:
  854. dtype = torch.uint8
  855. t_a = _util.cast_if_needed(t_a, dtype)
  856. t_b = _util.cast_if_needed(t_b, dtype)
  857. result = torch.vdot(t_a, t_b)
  858. if is_half:
  859. result = result.to(torch.float16)
  860. elif is_bool:
  861. result = result.to(torch.bool)
  862. return result
  863. def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
  864. if isinstance(axes, (list, tuple)):
  865. axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
  866. target_dtype = _dtypes_impl.result_type_impl(a, b)
  867. a = _util.cast_if_needed(a, target_dtype)
  868. b = _util.cast_if_needed(b, target_dtype)
  869. return torch.tensordot(a, b, dims=axes)
  870. def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
  871. dtype = _dtypes_impl.result_type_impl(a, b)
  872. is_bool = dtype == torch.bool
  873. if is_bool:
  874. dtype = torch.uint8
  875. a = _util.cast_if_needed(a, dtype)
  876. b = _util.cast_if_needed(b, dtype)
  877. if a.ndim == 0 or b.ndim == 0:
  878. result = a * b
  879. else:
  880. result = torch.matmul(a, b)
  881. if is_bool:
  882. result = result.to(torch.bool)
  883. return result
  884. def inner(a: ArrayLike, b: ArrayLike, /):
  885. dtype = _dtypes_impl.result_type_impl(a, b)
  886. is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu)
  887. is_bool = dtype == torch.bool
  888. if is_half:
  889. # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
  890. dtype = torch.float32
  891. elif is_bool:
  892. dtype = torch.uint8
  893. a = _util.cast_if_needed(a, dtype)
  894. b = _util.cast_if_needed(b, dtype)
  895. result = torch.inner(a, b)
  896. if is_half:
  897. result = result.to(torch.float16)
  898. elif is_bool:
  899. result = result.to(torch.bool)
  900. return result
  901. def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
  902. return torch.outer(a, b)
def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Cross product of 2- or 3-component vectors stored in ``a`` and ``b``.

    ``axisa``/``axisb`` select the vector axis of each input and ``axisc`` the
    vector axis of the result; a non-``None`` ``axis`` overrides all three.
    When both inputs hold 2-component vectors the result has no vector axis.

    Raises ``ValueError`` if the vector axis of either input does not have
    length 2 or 3.
    """
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685
    if axis is not None:
        axisa, axisb, axisc = (axis,) * 3

    # Check axisa and axisb are within bounds
    axisa = _util.normalize_axis_index(axisa, a.ndim)
    axisb = _util.normalize_axis_index(axisb, b.ndim)

    # Move working axis to the end of the shape
    a = torch.moveaxis(a, axisa, -1)
    b = torch.moveaxis(b, axisb, -1)
    msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)"
    if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
        raise ValueError(msg)

    # Create the output array
    shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape)
    if a.shape[-1] == 3 or b.shape[-1] == 3:
        shape += (3,)
        # Check axisc is within bounds
        axisc = _util.normalize_axis_index(axisc, len(shape))
    dtype = _dtypes_impl.result_type_impl(a, b)
    cp = torch.empty(shape, dtype=dtype)

    # recast arrays as dtype
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    # create local aliases for readability
    # (a2 / b2 / cp0..cp2 are only bound when the respective vector axis has
    # length 3; the branches below only touch them in those cases)
    a0 = a[..., 0]
    a1 = a[..., 1]
    if a.shape[-1] == 3:
        a2 = a[..., 2]
    b0 = b[..., 0]
    b1 = b[..., 1]
    if b.shape[-1] == 3:
        b2 = b[..., 2]
    if cp.ndim != 0 and cp.shape[-1] == 3:
        cp0 = cp[..., 0]
        cp1 = cp[..., 1]
        cp2 = cp[..., 2]

    if a.shape[-1] == 2:
        if b.shape[-1] == 2:
            # a0 * b1 - a1 * b0
            cp[...] = a0 * b1 - a1 * b0
            # no vector axis in the result, so there is nothing to move
            return cp
        else:
            assert b.shape[-1] == 3
            # cp0 = a1 * b2 - 0  (a2 = 0)
            # cp1 = 0 - a0 * b2  (a2 = 0)
            # cp2 = a0 * b1 - a1 * b0
            cp0[...] = a1 * b2
            cp1[...] = -a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
    else:
        assert a.shape[-1] == 3
        if b.shape[-1] == 3:
            cp0[...] = a1 * b2 - a2 * b1
            cp1[...] = a2 * b0 - a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
        else:
            assert b.shape[-1] == 2
            # same formulas as above with b2 = 0
            cp0[...] = -a2 * b1
            cp1[...] = a2 * b0
            cp2[...] = a0 * b1 - a1 * b0

    # move the vector axis of the result to the requested position
    return torch.moveaxis(cp, -1, axisc)
def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    """Evaluate an Einstein summation via ``torch.einsum``.

    Accepts both the ``("ij,jk->ik", a, b)`` subscripts-string form and the
    ``(a, sublist_a, b, sublist_b[, sublist_out])`` sublist form.

    Raises ``TypeError`` if ``out`` is given but is not an ``ndarray`` and
    ``NotImplementedError`` if ``order`` is anything other than ``"K"``.
    """
    # Have to manually normalize *operands and **kwargs, following the NumPy signature
    # We have a local import to avoid polluting the global space, as it will be then
    # exported in funcs.py
    from ._ndarray import ndarray
    from ._normalizations import (
        maybe_copy_to,
        normalize_array_like,
        normalize_casting,
        normalize_dtype,
        wrap_tensors,
    )

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, str, op, str ... [sublistout] format: normalize every other argument

        # - if sublistout is not given, the length of operands is even, and we pick
        #   odd-numbered elements, which are arrays.
        # - if sublistout is given, the length of operands is odd, we peel off
        #   the last one, and pick odd-numbered elements, which are arrays.
        #   Without [:-1], we would have picked sublistout, too.
        array_operands = operands[:-1][::2]
    else:
        # ("ij->", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

    # work around 'bmm' not implemented for 'Half' etc
    is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors)
    if is_half:
        target_dtype = torch.float32

    # the listed integer dtypes are computed in int64
    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
    if is_short_int:
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    from torch.backends import opt_einsum

    try:
        # set the global state to handle the optimize=... argument, restore on exit
        if opt_einsum.is_available():
            old_strategy = torch.backends.opt_einsum.strategy
            old_enabled = torch.backends.opt_einsum.enabled

            # torch.einsum calls opt_einsum.contract_path, which runs into
            # https://github.com/dgasmith/opt_einsum/issues/219
            # for strategy={True, False}
            if optimize is True:
                optimize = "auto"
            elif optimize is False:
                torch.backends.opt_einsum.enabled = False

            torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine operands
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            # interleave the normalized tensors with their sublists again
            operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        # restore the global opt_einsum state saved at the top of the try block
        if opt_einsum.is_available():
            torch.backends.opt_einsum.strategy = old_strategy
            torch.backends.opt_einsum.enabled = old_enabled

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)
  1039. # ### sort and partition ###
  1040. def _sort_helper(tensor, axis, kind, order):
  1041. if tensor.dtype.is_complex:
  1042. raise NotImplementedError(f"sorting {tensor.dtype} is not supported")
  1043. (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
  1044. axis = _util.normalize_axis_index(axis, tensor.ndim)
  1045. stable = kind == "stable"
  1046. return tensor, axis, stable
  1047. def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
  1048. # `order` keyword arg is only relevant for structured dtypes; so not supported here.
  1049. a, axis, stable = _sort_helper(a, axis, kind, order)
  1050. result = torch.sort(a, dim=axis, stable=stable)
  1051. return result.values
  1052. def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
  1053. a, axis, stable = _sort_helper(a, axis, kind, order)
  1054. return torch.argsort(a, dim=axis, stable=stable)
  1055. def searchsorted(
  1056. a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
  1057. ):
  1058. if a.dtype.is_complex:
  1059. raise NotImplementedError(f"searchsorted with dtype={a.dtype}")
  1060. return torch.searchsorted(a, v, side=side, sorter=sorter)
  1061. # ### swap/move/roll axis ###
  1062. def moveaxis(a: ArrayLike, source, destination):
  1063. source = _util.normalize_axis_tuple(source, a.ndim, "source")
  1064. destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
  1065. return torch.moveaxis(a, source, destination)
  1066. def swapaxes(a: ArrayLike, axis1, axis2):
  1067. axis1 = _util.normalize_axis_index(axis1, a.ndim)
  1068. axis2 = _util.normalize_axis_index(axis2, a.ndim)
  1069. return torch.swapaxes(a, axis1, axis2)
  1070. def rollaxis(a: ArrayLike, axis, start=0):
  1071. # Straight vendor from:
  1072. # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
  1073. #
  1074. # Also note this function in NumPy is mostly retained for backwards compat
  1075. # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
  1076. # so let's not touch it unless hard pressed.
  1077. n = a.ndim
  1078. axis = _util.normalize_axis_index(axis, n)
  1079. if start < 0:
  1080. start += n
  1081. msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
  1082. if not (0 <= start < n + 1):
  1083. raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
  1084. if axis < start:
  1085. # it's been removed
  1086. start -= 1
  1087. if axis == start:
  1088. # numpy returns a view, here we try returning the tensor itself
  1089. # return tensor[...]
  1090. return a
  1091. axes = list(range(0, n))
  1092. axes.remove(axis)
  1093. axes.insert(start, axis)
  1094. return a.view(axes)
  1095. def roll(a: ArrayLike, shift, axis=None):
  1096. if axis is not None:
  1097. axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
  1098. if not isinstance(shift, tuple):
  1099. shift = (shift,) * len(axis)
  1100. return torch.roll(a, shift, axis)
  1101. # ### shape manipulations ###
  1102. def squeeze(a: ArrayLike, axis=None):
  1103. if axis == ():
  1104. result = a
  1105. elif axis is None:
  1106. result = a.squeeze()
  1107. else:
  1108. if isinstance(axis, tuple):
  1109. result = a
  1110. for ax in axis:
  1111. result = a.squeeze(ax)
  1112. else:
  1113. result = a.squeeze(axis)
  1114. return result
  1115. def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
  1116. # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
  1117. newshape = newshape[0] if len(newshape) == 1 else newshape
  1118. return a.reshape(newshape)
  1119. # NB: cannot use torch.reshape(a, newshape) above, because of
  1120. # (Pdb) torch.reshape(torch.as_tensor([1]), 1)
  1121. # *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int
  1122. def transpose(a: ArrayLike, axes=None):
  1123. # numpy allows both .transpose(sh) and .transpose(*sh)
  1124. # also older code uses axes being a list
  1125. if axes in [(), None, (None,)]:
  1126. axes = tuple(reversed(range(a.ndim)))
  1127. elif len(axes) == 1:
  1128. axes = axes[0]
  1129. return a.permute(axes)
  1130. def ravel(a: ArrayLike, order: NotImplementedType = "C"):
  1131. return torch.flatten(a)
  1132. def diff(
  1133. a: ArrayLike,
  1134. n=1,
  1135. axis=-1,
  1136. prepend: Optional[ArrayLike] = None,
  1137. append: Optional[ArrayLike] = None,
  1138. ):
  1139. axis = _util.normalize_axis_index(axis, a.ndim)
  1140. if n < 0:
  1141. raise ValueError(f"order must be non-negative but got {n}")
  1142. if n == 0:
  1143. # match numpy and return the input immediately
  1144. return a
  1145. if prepend is not None:
  1146. shape = list(a.shape)
  1147. shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1
  1148. prepend = torch.broadcast_to(prepend, shape)
  1149. if append is not None:
  1150. shape = list(a.shape)
  1151. shape[axis] = append.shape[axis] if append.ndim > 0 else 1
  1152. append = torch.broadcast_to(append, shape)
  1153. return torch.diff(a, n, axis=axis, prepend=prepend, append=append)
  1154. # ### math functions ###
  1155. def angle(z: ArrayLike, deg=False):
  1156. result = torch.angle(z)
  1157. if deg:
  1158. result = result * (180 / torch.pi)
  1159. return result
  1160. def sinc(x: ArrayLike):
  1161. return torch.sinc(x)
  1162. # NB: have to normalize *varargs manually
def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
    """Numerical gradient of ``f`` along one or more axes.

    ``varargs`` gives the spacing: nothing (unit spacing), a single scalar or
    0-d tensor for all axes, or one scalar / 1-d coordinate tensor per axis.
    Returns a single tensor when exactly one axis is differentiated and a
    list of tensors (one per axis) otherwise.

    Raises ``TypeError`` for an invalid number of spacing arguments and
    ``ValueError`` for malformed spacings, ``edge_order > 2``, or an axis with
    fewer than ``edge_order + 1`` elements.
    """
    N = f.ndim  # number of dimensions

    varargs = _util.ndarrays_to_tensors(varargs)

    if axis is None:
        axes = tuple(range(N))
    else:
        axes = _util.normalize_axis_tuple(axis, N)

    len_axes = len(axes)
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
        # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = torch.as_tensor(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1d")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError(
                    "when 1d, distances must match "
                    "the length of the corresponding dimension"
                )
            if not (distances.dtype.is_floating_point or distances.dtype.is_complex):
                distances = distances.double()

            # coordinates -> spacings between consecutive coordinates
            diffx = torch.diff(distances)
            # if distances are constant reduce to the scalar case
            # since it brings a consistent speedup
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    # use central differences on interior and one-sided differences on the
    # endpoints. This preserves second order-accuracy over the full domain.

    outvals = []

    # create slice objects --- initially all are [:, :, ..., :]
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    if _dtypes_impl.python_type_for_torch(otype) in (int, bool):
        # Convert to floating point.
        # First check if f is a numpy integer type; if so, convert f to float64
        # to avoid modular arithmetic when computing the changes in f.
        f = f.double()
        otype = torch.float64

    # NB: the loop variable below rebinds the ``axis`` parameter
    for axis, ax_dx in zip(axes, dx):
        if f.shape[axis] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required."
            )
        # result allocation
        out = torch.empty_like(f, dtype=otype)

        # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

        # Numerical differentiation: 2nd order interior
        slice1[axis] = slice(1, -1)
        slice2[axis] = slice(None, -2)
        slice3[axis] = slice(1, -1)
        slice4[axis] = slice(2, None)

        if uniform_spacing:
            out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
        else:
            dx1 = ax_dx[0:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # fix the shape for broadcasting
            shape = [1] * N
            shape[axis] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)
            # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[axis] = 0
            slice2[axis] = 1
            slice3[axis] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0

            slice1[axis] = -1
            slice2[axis] = -1
            slice3[axis] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n

        # Numerical differentiation: 2nd order edges
        else:
            slice1[axis] = 0
            slice2[axis] = 0
            slice3[axis] = 1
            slice4[axis] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2.0 / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[axis] = -1
            slice2[axis] = -3
            slice3[axis] = -2
            slice4[axis] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2.0 / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = (dx2) / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[axis] = slice(None)
        slice2[axis] = slice(None)
        slice3[axis] = slice(None)
        slice4[axis] = slice(None)

    if len_axes == 1:
        return outvals[0]
    else:
        return outvals
  1314. # ### Type/shape etc queries ###
  1315. def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
  1316. if a.is_floating_point():
  1317. result = torch.round(a, decimals=decimals)
  1318. elif a.is_complex():
  1319. # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
  1320. result = torch.complex(
  1321. torch.round(a.real, decimals=decimals),
  1322. torch.round(a.imag, decimals=decimals),
  1323. )
  1324. else:
  1325. # RuntimeError: "round_cpu" not implemented for 'int'
  1326. result = a
  1327. return result
# Aliases: both names are bound to the very same function object as `round`.
around = round
round_ = round
  1330. def real_if_close(a: ArrayLike, tol=100):
  1331. if not torch.is_complex(a):
  1332. return a
  1333. if tol > 1:
  1334. # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
  1335. # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
  1336. # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
  1337. tol = tol * torch.finfo(a.dtype).eps
  1338. mask = torch.abs(a.imag) < tol
  1339. return a.real if mask.all() else a
  1340. def real(a: ArrayLike):
  1341. return torch.real(a)
  1342. def imag(a: ArrayLike):
  1343. if a.is_complex():
  1344. return a.imag
  1345. return torch.zeros_like(a)
  1346. def iscomplex(x: ArrayLike):
  1347. if torch.is_complex(x):
  1348. return x.imag != 0
  1349. return torch.zeros_like(x, dtype=torch.bool)
  1350. def isreal(x: ArrayLike):
  1351. if torch.is_complex(x):
  1352. return x.imag == 0
  1353. return torch.ones_like(x, dtype=torch.bool)
  1354. def iscomplexobj(x: ArrayLike):
  1355. return torch.is_complex(x)
  1356. def isrealobj(x: ArrayLike):
  1357. return not torch.is_complex(x)
  1358. def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
  1359. return torch.isneginf(x)
  1360. def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
  1361. return torch.isposinf(x)
  1362. def i0(x: ArrayLike):
  1363. return torch.special.i0(x)
  1364. def isscalar(a):
  1365. # We need to use normalize_array_like, but we don't want to export it in funcs.py
  1366. from ._normalizations import normalize_array_like
  1367. try:
  1368. t = normalize_array_like(a)
  1369. return t.numel() == 1
  1370. except Exception:
  1371. return False
  1372. # ### Filter windows ###
  1373. def hamming(M):
  1374. dtype = _dtypes_impl.default_dtypes().float_dtype
  1375. return torch.hamming_window(M, periodic=False, dtype=dtype)
  1376. def hanning(M):
  1377. dtype = _dtypes_impl.default_dtypes().float_dtype
  1378. return torch.hann_window(M, periodic=False, dtype=dtype)
  1379. def kaiser(M, beta):
  1380. dtype = _dtypes_impl.default_dtypes().float_dtype
  1381. return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
  1382. def blackman(M):
  1383. dtype = _dtypes_impl.default_dtypes().float_dtype
  1384. return torch.blackman_window(M, periodic=False, dtype=dtype)
  1385. def bartlett(M):
  1386. dtype = _dtypes_impl.default_dtypes().float_dtype
  1387. return torch.bartlett_window(M, periodic=False, dtype=dtype)
  1388. # ### Dtype routines ###
  1389. # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
# Lookup table used by `common_type`: indexed as array_type[is_complex][precision],
# where row 0 holds the real dtypes and row 1 the complex ones.
# There is no complex entry at precision 0, hence the None placeholder.
array_type = [
    [torch.float16, torch.float32, torch.float64],
    [None, torch.complex64, torch.complex128],
]
# Maps a floating-point/complex dtype to its column ("precision") in array_type.
array_precision = {
    torch.float16: 0,
    torch.float32: 1,
    torch.float64: 2,
    torch.complex64: 1,
    torch.complex128: 2,
}
  1401. def common_type(*tensors: ArrayLike):
  1402. is_complex = False
  1403. precision = 0
  1404. for a in tensors:
  1405. t = a.dtype
  1406. if iscomplexobj(a):
  1407. is_complex = True
  1408. if not (t.is_floating_point or t.is_complex):
  1409. p = 2 # array_precision[_nx.double]
  1410. else:
  1411. p = array_precision.get(t, None)
  1412. if p is None:
  1413. raise TypeError("can't get common type for non-numeric array")
  1414. precision = builtins.max(precision, p)
  1415. if is_complex:
  1416. return array_type[1][precision]
  1417. else:
  1418. return array_type[0][precision]
  1419. # ### histograms ###
  1420. def histogram(
  1421. a: ArrayLike,
  1422. bins: ArrayLike = 10,
  1423. range=None,
  1424. normed=None,
  1425. weights: Optional[ArrayLike] = None,
  1426. density=None,
  1427. ):
  1428. if normed is not None:
  1429. raise ValueError("normed argument is deprecated, use density= instead")
  1430. if weights is not None and weights.dtype.is_complex:
  1431. raise NotImplementedError("complex weights histogram.")
  1432. is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
  1433. is_w_int = weights is None or not weights.dtype.is_floating_point
  1434. if is_a_int:
  1435. a = a.double()
  1436. if weights is not None:
  1437. weights = _util.cast_if_needed(weights, a.dtype)
  1438. if isinstance(bins, torch.Tensor):
  1439. if bins.ndim == 0:
  1440. # bins was a single int
  1441. bins = operator.index(bins)
  1442. else:
  1443. bins = _util.cast_if_needed(bins, a.dtype)
  1444. if range is None:
  1445. h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
  1446. else:
  1447. h, b = torch.histogram(
  1448. a, bins, range=range, weight=weights, density=bool(density)
  1449. )
  1450. if not density and is_w_int:
  1451. h = h.long()
  1452. if is_a_int:
  1453. b = b.long()
  1454. return h, b
  1455. def histogram2d(
  1456. x,
  1457. y,
  1458. bins=10,
  1459. range: Optional[ArrayLike] = None,
  1460. normed=None,
  1461. weights: Optional[ArrayLike] = None,
  1462. density=None,
  1463. ):
  1464. # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821
  1465. if len(x) != len(y):
  1466. raise ValueError("x and y must have the same length.")
  1467. try:
  1468. N = len(bins)
  1469. except TypeError:
  1470. N = 1
  1471. if N != 1 and N != 2:
  1472. bins = [bins, bins]
  1473. h, e = histogramdd((x, y), bins, range, normed, weights, density)
  1474. return h, e[0], e[1]
  1475. def histogramdd(
  1476. sample,
  1477. bins=10,
  1478. range: Optional[ArrayLike] = None,
  1479. normed=None,
  1480. weights: Optional[ArrayLike] = None,
  1481. density=None,
  1482. ):
  1483. # have to normalize manually because `sample` interpretation differs
  1484. # for a list of lists and a 2D array
  1485. if normed is not None:
  1486. raise ValueError("normed argument is deprecated, use density= instead")
  1487. from ._normalizations import normalize_array_like, normalize_seq_array_like
  1488. if isinstance(sample, (list, tuple)):
  1489. sample = normalize_array_like(sample).T
  1490. else:
  1491. sample = normalize_array_like(sample)
  1492. sample = torch.atleast_2d(sample)
  1493. if not (sample.dtype.is_floating_point or sample.dtype.is_complex):
  1494. sample = sample.double()
  1495. # bins is either an int, or a sequence of ints or a sequence of arrays
  1496. bins_is_array = not (
  1497. isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins)
  1498. )
  1499. if bins_is_array:
  1500. bins = normalize_seq_array_like(bins)
  1501. bins_dtypes = [b.dtype for b in bins]
  1502. bins = [_util.cast_if_needed(b, sample.dtype) for b in bins]
  1503. if range is not None:
  1504. range = range.flatten().tolist()
  1505. if weights is not None:
  1506. # range=... is required : interleave min and max values per dimension
  1507. mm = sample.aminmax(dim=0)
  1508. range = torch.cat(mm).reshape(2, -1).T.flatten()
  1509. range = tuple(range.tolist())
  1510. weights = _util.cast_if_needed(weights, sample.dtype)
  1511. w_kwd = {"weight": weights}
  1512. else:
  1513. w_kwd = {}
  1514. h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd)
  1515. if bins_is_array:
  1516. b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)]
  1517. return h, b
  1518. # ### odds and ends
  1519. def min_scalar_type(a: ArrayLike, /):
  1520. # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288
  1521. from ._dtypes import DType
  1522. if a.numel() > 1:
  1523. # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified."
  1524. return DType(a.dtype)
  1525. if a.dtype == torch.bool:
  1526. dtype = torch.bool
  1527. elif a.dtype.is_complex:
  1528. fi = torch.finfo(torch.float32)
  1529. fits_in_single = a.dtype == torch.complex64 or (
  1530. fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
  1531. )
  1532. dtype = torch.complex64 if fits_in_single else torch.complex128
  1533. elif a.dtype.is_floating_point:
  1534. for dt in [torch.float16, torch.float32, torch.float64]:
  1535. fi = torch.finfo(dt)
  1536. if fi.min <= a <= fi.max:
  1537. dtype = dt
  1538. break
  1539. else:
  1540. # must be integer
  1541. for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
  1542. # Prefer unsigned int where possible, as numpy does.
  1543. ii = torch.iinfo(dt)
  1544. if ii.min <= a <= ii.max:
  1545. dtype = dt
  1546. break
  1547. return DType(dtype)
  1548. def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs):
  1549. if mode != "constant":
  1550. raise NotImplementedError
  1551. value = kwargs.get("constant_values", 0)
  1552. # `value` must be a python scalar for torch.nn.functional.pad
  1553. typ = _dtypes_impl.python_type_for_torch(array.dtype)
  1554. value = typ(value)
  1555. pad_width = torch.broadcast_to(pad_width, (array.ndim, 2))
  1556. pad_width = torch.flip(pad_width, (0,)).flatten()
  1557. return torch.nn.functional.pad(array, tuple(pad_width), value=value)