_reductions_impl.py

# mypy: ignore-errors
""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.

Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
"""
from __future__ import annotations

import functools
from typing import Optional, TYPE_CHECKING

import torch

from . import _dtypes_impl, _util

if TYPE_CHECKING:
    from ._normalizations import (
        ArrayLike,
        AxisLike,
        DTypeLike,
        KeepDims,
        NotImplementedType,
        OutArray,
    )


def _deco_axis_expand(func):
    """
    Generically handle axis arguments in reductions.
    axis is *always* the 2nd arg of the wrapped function, so there is no need to inspect its signature.
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_tuple(axis, a.ndim)

        if axis == ():
            # NumPy treats axis=() as a reduction over no axes, so we insert a
            # length-one axis and run the reduction along it.
            # We cannot return a.clone() as this would sidestep the checks inside the function.
            newshape = _util.expand_shape(a.shape, axis=0)
            a = a.reshape(newshape)
            axis = (0,)

        return func(a, axis, *args, **kwds)

    return wrapped
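

# A sketch of what the axis=() branch buys us (illustrative only; these impls
# receive plain torch tensors from the public layer). NumPy treats axis=() as
# "reduce over no axes", i.e. an elementwise identity-like reduction:
#
#   >>> t = torch.ones(2)
#   >>> sum(t, axis=())     # like np.sum(np.ones(2), axis=()) -> array([1., 1.])
#   tensor([1., 1.])
#   >>> sum(t)              # axis=None still reduces over everything
#   tensor(2.)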


def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    If `dtype` is None, fall back to `other_dtype`. Boolean and integer dtypes
    are promoted to the default float dtype; real and complex floating-point
    dtypes are passed through unchanged.
    """
    if dtype is None:
        dtype = other_dtype
    if not (dtype.is_floating_point or dtype.is_complex):
        return _dtypes_impl.default_dtypes().float_dtype
    return dtype
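

# A quick sketch of the promotion rule (illustrative; the "default float" is
# whatever _dtypes_impl.default_dtypes().float_dtype is configured to,
# typically torch.float64):
#
#   >>> _atleast_float(None, torch.int64)            # bool/int -> default float dtype
#   torch.float64
#   >>> _atleast_float(torch.float16, torch.int64)   # an explicit float dtype wins
#   torch.float16
#   >>> _atleast_float(None, torch.complex128)       # complex passes through unchanged
#   torch.complex128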


@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    return a.count_nonzero(axis)


@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmax_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmax(a, axis)


@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmin_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmin(a, axis)


@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.any(a, **axis_kw)


@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.all(a, **axis_kw)


@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amax with dtype={a.dtype}")

    return a.amax(axis)


max = amax


@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amin with dtype={a.dtype}")

    return a.amin(axis)


min = amin


@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    return a.amax(axis) - a.amin(axis)


@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    assert dtype is None or isinstance(dtype, torch.dtype)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.sum(dtype=dtype, **axis_kw)


@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.prod(dtype=dtype, **axis_kw)


product = prod


@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    dtype = _atleast_float(dtype, a.dtype)

    axis_kw = {} if axis is None else {"dim": axis}
    result = a.mean(dtype=dtype, **axis_kw)

    return result


@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.std(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.var(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


# cumsum / cumprod are almost reductions:
#   1. no keepdims
#   2. axis=None flattens


def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)


def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


cumproduct = cumprod
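

# Sketch of the axis=None flattening noted above (illustrative; plain torch tensors assumed):
#
#   >>> t = torch.tensor([[1, 2], [3, 4]])
#   >>> cumsum(t)            # axis=None: flatten first, then accumulate
#   tensor([ 1,  3,  6, 10])
#   >>> cumsum(t, axis=1)    # per-row running sums; the shape is preserved
#   tensor([[1, 3],
#           [3, 7]])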


def average(
    a: ArrayLike,
    axis=None,
    weights: ArrayLike = None,
    returned=False,
    *,
    keepdims=False,
):
    if weights is None:
        result = mean(a, axis=axis)
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        if not a.dtype.is_floating_point:
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # setup weight to broadcast along axis
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        # do the work
        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)
        wsum = sum(weights, axis, dtype=result_dtype)

        result = numerator / wsum

    # We process keepdims manually because the decorator does not deal with variadic returns
    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    else:
        return result
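

# A small worked example of the weighted branch (illustrative; plain torch tensors assumed):
#
#   >>> a = torch.arange(6.0).reshape(2, 3)
#   >>> w = torch.tensor([1.0, 2.0, 3.0])      # 1D weights broadcast along axis=1
#   >>> average(a, axis=1, weights=w)          # sum(a * w, axis=1) / sum(w)
#   tensor([1.3333, 4.3333])
#   >>> average(a, axis=1, weights=w, returned=True)[1]   # weight sums, broadcast to result shape
#   tensor([6., 6.])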


# Not using deco_axis_expand as it assumes that axis is the second arg
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    if overwrite_input:
        # raise NotImplementedError("overwrite_input in quantile not implemented.")
        # NumPy documents that `overwrite_input` MAY modify inputs:
        # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
        # Here we choose to work out-of-place.
        pass

    if not a.dtype.is_floating_point:
        dtype = _dtypes_impl.default_dtypes().float_dtype
        a = a.to(dtype)

    # edge case: torch.quantile only supports float32 and float64
    if a.dtype == torch.float16:
        a = a.to(torch.float32)

    if axis is None:
        a = a.flatten()
        q = q.flatten()
        axis = (0,)
    else:
        axis = _util.normalize_axis_tuple(axis, a.ndim)

    # FIXME(Mario) Doesn't np.quantile accept a tuple axis?
    # torch.quantile does accept a number. If we don't want to implement the tuple
    # behaviour (it's definitely low priority), change `normalize_axis_tuple` into a
    # `normalize_axis_index` above.
    axis = _util.allow_only_single_axis(axis)

    q = _util.cast_if_needed(q, a.dtype)

    return torch.quantile(a, q, axis=axis, interpolation=method)
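

# A short sketch of the call above (illustrative; plain torch tensors assumed). As in
# NumPy, a vector of quantiles becomes the leading axis of the result:
#
#   >>> a = torch.tensor([[7.0, 1.0], [3.0, 5.0]])
#   >>> quantile(a, torch.tensor([0.25, 0.75]), axis=1)
#   tensor([[2.5000, 3.5000],
#           [5.5000, 4.5000]])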


def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32
    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)
    qq = q / 100.0

    return quantile(
        a,
        qq,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )
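

# The int -> float cast above keeps q / 100.0 floating-point; a minimal sketch
# (illustrative; plain torch tensors assumed):
#
#   >>> a = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   >>> percentile(a, torch.tensor(30))    # integer q is cast, then scaled to 0.30
#   tensor([1.9000])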


def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    return quantile(
        a,
        torch.as_tensor(0.5),
        axis=axis,
        overwrite_input=overwrite_input,
        out=out,
        keepdims=keepdims,
    )