_ufuncs.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. from typing import Optional
  4. import torch
  5. from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
  6. from ._normalizations import (
  7. ArrayLike,
  8. ArrayLikeOrScalar,
  9. CastingModes,
  10. DTypeLike,
  11. normalizer,
  12. NotImplementedType,
  13. OutArray,
  14. )
  15. def _ufunc_postprocess(result, out, casting):
  16. if out is not None:
  17. result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
  18. result = torch.broadcast_to(result, out.shape)
  19. return result
  20. # ############# Binary ufuncs ######################
  21. _binary = [
  22. name
  23. for name in dir(_binary_ufuncs_impl)
  24. if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"]
  25. ]
  26. NEP50_FUNCS = (
  27. "add",
  28. "subtract",
  29. "multiply",
  30. "floor_divide",
  31. "true_divide",
  32. "divide",
  33. "remainder",
  34. "bitwise_and",
  35. "bitwise_or",
  36. "bitwise_xor",
  37. "bitwise_left_shift",
  38. "bitwise_right_shift",
  39. "hypot",
  40. "arctan2",
  41. "logaddexp",
  42. "logaddexp2",
  43. "heaviside",
  44. "copysign",
  45. "fmax",
  46. "minimum",
  47. "fmin",
  48. "maximum",
  49. "fmod",
  50. "gcd",
  51. "lcm",
  52. "pow",
  53. )
  54. def deco_binary_ufunc(torch_func):
  55. """Common infra for binary ufuncs.
  56. Normalize arguments, sort out type casting, broadcasting and delegate to
  57. the pytorch functions for the actual work.
  58. """
  59. @normalizer
  60. def wrapped(
  61. x1: ArrayLikeOrScalar,
  62. x2: ArrayLikeOrScalar,
  63. /,
  64. out: Optional[OutArray] = None,
  65. *,
  66. where: NotImplementedType = True,
  67. casting: Optional[CastingModes] = "same_kind",
  68. order: NotImplementedType = "K",
  69. dtype: Optional[DTypeLike] = None,
  70. subok: NotImplementedType = False,
  71. signature: NotImplementedType = None,
  72. extobj: NotImplementedType = None,
  73. ):
  74. if dtype is not None:
  75. def cast(x, dtype):
  76. if isinstance(x, torch.Tensor):
  77. return _util.typecast_tensor(x, dtype, casting)
  78. else:
  79. return torch.as_tensor(x, dtype=dtype)
  80. x1 = cast(x1, dtype)
  81. x2 = cast(x2, dtype)
  82. elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
  83. dtype = _dtypes_impl.result_type_impl(x1, x2)
  84. x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
  85. else:
  86. x1, x2 = _dtypes_impl.nep50_to_tensors(
  87. x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__
  88. )
  89. result = torch_func(x1, x2)
  90. return _ufunc_postprocess(result, out, casting)
  91. wrapped.__qualname__ = torch_func.__name__
  92. wrapped.__name__ = torch_func.__name__
  93. return wrapped
  94. # matmul's signature is _slightly_ different from other ufuncs:
  95. # - no where=...
  96. # - additional axis=..., axes=...
  97. # - no NEP50 scalars in or out
  98. @normalizer
  99. def matmul(
  100. x1: ArrayLike,
  101. x2: ArrayLike,
  102. /,
  103. out: Optional[OutArray] = None,
  104. *,
  105. casting: Optional[CastingModes] = "same_kind",
  106. order: NotImplementedType = "K",
  107. dtype: Optional[DTypeLike] = None,
  108. subok: NotImplementedType = False,
  109. signature: NotImplementedType = None,
  110. extobj: NotImplementedType = None,
  111. axes: NotImplementedType = None,
  112. axis: NotImplementedType = None,
  113. ):
  114. if dtype is None:
  115. dtype = _dtypes_impl.result_type_impl(x1, x2)
  116. x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
  117. result = _binary_ufuncs_impl.matmul(x1, x2)
  118. result = _ufunc_postprocess(result, out, casting)
  119. return result
  120. # ldexp casting is special : the dtype of the result == dtype of the 1st arg
  121. @normalizer
  122. def ldexp(
  123. x1: ArrayLikeOrScalar,
  124. x2: ArrayLikeOrScalar,
  125. /,
  126. out: Optional[OutArray] = None,
  127. *,
  128. where: NotImplementedType = True,
  129. casting: Optional[CastingModes] = "same_kind",
  130. order: NotImplementedType = "K",
  131. dtype: Optional[DTypeLike] = None,
  132. subok: NotImplementedType = False,
  133. signature: NotImplementedType = None,
  134. extobj: NotImplementedType = None,
  135. ):
  136. if dtype is not None:
  137. if isinstance(x1, torch.Tensor):
  138. x1 = _util.typecast_tensor(x1, dtype, casting)
  139. else:
  140. x1 = torch.as_tensor(x1, dtype=dtype)
  141. else:
  142. if not isinstance(x1, torch.Tensor):
  143. x1 = torch.as_tensor(x1)
  144. x1 = _util.cast_int_to_float(x1)
  145. x2 = torch.as_tensor(x2)
  146. # the second arg must be integer
  147. if _dtypes_impl._category(x2.dtype) != 1:
  148. raise ValueError("ldexp 2nd arg must be integer")
  149. result = _binary_ufuncs_impl.ldexp(x1, x2)
  150. if x1.dtype == torch.float16:
  151. # torch.ldexp(f16, int) -> f32, undo it
  152. result = result.to(torch.float16)
  153. return _ufunc_postprocess(result, out, casting)
  154. # nin=2, nout=2
  155. @normalizer
  156. def divmod(
  157. x1: ArrayLike,
  158. x2: ArrayLike,
  159. out1: Optional[OutArray] = None,
  160. out2: Optional[OutArray] = None,
  161. /,
  162. out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
  163. *,
  164. where: NotImplementedType = True,
  165. casting: Optional[CastingModes] = "same_kind",
  166. order: NotImplementedType = "K",
  167. dtype: Optional[DTypeLike] = None,
  168. subok: NotImplementedType = False,
  169. signature: NotImplementedType = None,
  170. extobj: NotImplementedType = None,
  171. ):
  172. # make sure we either have no out arrays at all, or there is either
  173. # out1, out2, or out=tuple, but not both
  174. num_outs = sum(x is not None for x in [out1, out2])
  175. if num_outs == 1:
  176. raise ValueError("both out1 and out2 need to be provided")
  177. elif num_outs == 2:
  178. o1, o2 = out
  179. if o1 is not None or o2 is not None:
  180. raise TypeError(
  181. "cannot specify 'out' as both a positional and keyword argument"
  182. )
  183. else:
  184. out1, out2 = out
  185. if dtype is None:
  186. dtype = _dtypes_impl.result_type_impl(x1, x2)
  187. x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
  188. quot, rem = _binary_ufuncs_impl.divmod(x1, x2)
  189. quot = _ufunc_postprocess(quot, out1, casting)
  190. rem = _ufunc_postprocess(rem, out2, casting)
  191. return quot, rem
  192. #
  193. # Attach ufuncs to this module, for a further export to the public namespace in __init__.py
  194. #
  195. for name in _binary:
  196. ufunc = getattr(_binary_ufuncs_impl, name)
  197. vars()[name] = deco_binary_ufunc(ufunc)
  198. def modf(x, /, *args, **kwds):
  199. quot, rem = divmod(x, 1, *args, **kwds)
  200. return rem, quot
  201. _binary = _binary + ["divmod", "modf", "matmul", "ldexp"]
  202. # ############# Unary ufuncs ######################
  203. _unary = [
  204. name
  205. for name in dir(_unary_ufuncs_impl)
  206. if not name.startswith("_") and name != "torch"
  207. ]
  208. # these are ufunc(int) -> float
  209. _fp_unary = [
  210. "arccos",
  211. "arccosh",
  212. "arcsin",
  213. "arcsinh",
  214. "arctan",
  215. "arctanh",
  216. "cbrt",
  217. "cos",
  218. "cosh",
  219. "deg2rad",
  220. "degrees",
  221. "exp",
  222. "exp2",
  223. "expm1",
  224. "log",
  225. "log10",
  226. "log1p",
  227. "log2",
  228. "rad2deg",
  229. "radians",
  230. "reciprocal",
  231. "sin",
  232. "sinh",
  233. "sqrt",
  234. "square",
  235. "tan",
  236. "tanh",
  237. "trunc",
  238. ]
  239. def deco_unary_ufunc(torch_func):
  240. """Common infra for unary ufuncs.
  241. Normalize arguments, sort out type casting, broadcasting and delegate to
  242. the pytorch functions for the actual work.
  243. """
  244. @normalizer
  245. def wrapped(
  246. x: ArrayLike,
  247. /,
  248. out: Optional[OutArray] = None,
  249. *,
  250. where=True,
  251. casting: Optional[CastingModes] = "same_kind",
  252. order="K",
  253. dtype: Optional[DTypeLike] = None,
  254. subok: NotImplementedType = False,
  255. signature=None,
  256. extobj=None,
  257. ):
  258. if dtype is not None:
  259. x = _util.typecast_tensor(x, dtype, casting)
  260. if torch_func.__name__ in _fp_unary:
  261. x = _util.cast_int_to_float(x)
  262. result = torch_func(x)
  263. result = _ufunc_postprocess(result, out, casting)
  264. return result
  265. wrapped.__qualname__ = torch_func.__name__
  266. wrapped.__name__ = torch_func.__name__
  267. return wrapped
  268. #
  269. # Attach ufuncs to this module, for a further export to the public namespace in __init__.py
  270. #
  271. for name in _unary:
  272. ufunc = getattr(_unary_ufuncs_impl, name)
  273. vars()[name] = deco_unary_ufunc(ufunc)
  274. __all__ = _binary + _unary # noqa: PLE0605