# mypy: ignore-errors

from __future__ import annotations

from typing import Optional

import torch

from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
from ._normalizations import (
    ArrayLike,
    ArrayLikeOrScalar,
    CastingModes,
    DTypeLike,
    normalizer,
    NotImplementedType,
    OutArray,
)


def _ufunc_postprocess(result, out, casting):
    # If an out= array was given, cast the result to out's dtype and broadcast
    # it to out's shape.
    if out is not None:
        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
        result = torch.broadcast_to(result, out.shape)
    return result


# ############# Binary ufuncs ######################


_binary = [
    name
    for name in dir(_binary_ufuncs_impl)
    if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"]
]

NEP50_FUNCS = (
    "add",
    "subtract",
    "multiply",
    "floor_divide",
    "true_divide",
    "divide",
    "remainder",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "hypot",
    "arctan2",
    "logaddexp",
    "logaddexp2",
    "heaviside",
    "copysign",
    "fmax",
    "minimum",
    "fmin",
    "maximum",
    "fmod",
    "gcd",
    "lcm",
    "pow",
)
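
# A hedged note on the split above: for the names in NEP50_FUNCS, Python scalar
# operands get NEP 50 "weak scalar" treatment in _dtypes_impl.nep50_to_tensors
# (e.g. adding the Python int 3 to an int8 array keeps int8 rather than
# upcasting); names outside the tuple still go through nep50_to_tensors, but
# with the weak-scalar handling switched off.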


def deco_binary_ufunc(torch_func):
    """Common infra for binary ufuncs.

    Normalize arguments, sort out type casting and broadcasting, and delegate
    to the PyTorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:
            # an explicit dtype= wins: cast both operands to it

            def cast(x, dtype):
                if isinstance(x, torch.Tensor):
                    return _util.typecast_tensor(x, dtype, casting)
                else:
                    return torch.as_tensor(x, dtype=dtype)

            x1 = cast(x1, dtype)
            x2 = cast(x2, dtype)
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            # two tensors: use the usual result_type promotion
            dtype = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
        else:
            # at least one Python scalar: apply the NEP 50 weak-scalar handling
            # for the functions listed in NEP50_FUNCS
            x1, x2 = _dtypes_impl.nep50_to_tensors(
                x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__
            )

        result = torch_func(x1, x2)

        return _ufunc_postprocess(result, out, casting)

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped
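
# Usage sketch (illustrative): once the loop further down attaches the wrapped
# functions to this module, a call like ``add(x1, x2, out=o)`` normalizes the
# inputs, promotes their dtypes, calls ``_binary_ufuncs_impl.add``, and runs
# the result through _ufunc_postprocess to match ``o``'s dtype and shape.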


# matmul's signature is _slightly_ different from other ufuncs:
#  - no where=...
#  - additional axis=..., axes=...
#  - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    result = _binary_ufuncs_impl.matmul(x1, x2)

    result = _ufunc_postprocess(result, out, casting)
    return result
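
# In other words (a sketch, not a test): ``matmul(a, b, dtype="float64")`` casts
# both operands before multiplying, while passing ``where=`` fails with a
# TypeError simply because the wrapper above has no such parameter.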


# ldexp casting is special: the dtype of the result == dtype of the 1st arg
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    if dtype is not None:
        if isinstance(x1, torch.Tensor):
            x1 = _util.typecast_tensor(x1, dtype, casting)
        else:
            x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        if not isinstance(x1, torch.Tensor):
            x1 = torch.as_tensor(x1)
            x1 = _util.cast_int_to_float(x1)

    x2 = torch.as_tensor(x2)
    # the second arg must be an integer
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    if x1.dtype == torch.float16:
        # torch.ldexp(f16, int) -> f32, undo it
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)
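
# E.g. (illustrative only): ``ldexp(x, 3)`` with a float16 ``x`` comes back as
# float16 even though torch.ldexp itself promotes float16 to float32, and an
# exponent with a floating-point dtype raises the ValueError above.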


# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    # The out arrays may be given either positionally (out1, out2) or via the
    # out=(out1, out2) keyword, but not both; a single positional out array is
    # an error.
    num_outs = sum(x is not None for x in [out1, out2])
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    elif num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    quot = _ufunc_postprocess(quot, out1, casting)
    rem = _ufunc_postprocess(rem, out2, casting)
    return quot, rem
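
# Sketch of the out-handling above: ``divmod(x, y, q, r)`` and
# ``divmod(x, y, out=(q, r))`` are both accepted; ``divmod(x, y, q)`` raises a
# ValueError, and supplying both positional and keyword outs raises a TypeError.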


#
# Attach ufuncs to this module, for further export to the public namespace in __init__.py
#
for name in _binary:
    ufunc = getattr(_binary_ufuncs_impl, name)
    vars()[name] = deco_binary_ufunc(ufunc)


def modf(x, /, *args, **kwds):
    # NumPy's modf returns (fractional part, integral part), hence the swap
    quot, rem = divmod(x, 1, *args, **kwds)
    return rem, quot


_binary = _binary + ["divmod", "modf", "matmul", "ldexp"]
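
# vars() above is this module's namespace dict, so assigning into it is
# equivalent to writing ``add = deco_binary_ufunc(_binary_ufuncs_impl.add)``
# and so on by hand (an equivalent sketch, not code that exists in this file);
# __init__.py then re-exports these names to the public namespace.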


# ############# Unary ufuncs ######################


_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not name.startswith("_") and name != "torch"
]

# these are ufunc(int) -> float
_fp_unary = [
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "cbrt",
    "cos",
    "cosh",
    "deg2rad",
    "degrees",
    "exp",
    "exp2",
    "expm1",
    "log",
    "log10",
    "log1p",
    "log2",
    "rad2deg",
    "radians",
    "reciprocal",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]
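
# E.g. (a sketch): for the names above, integer inputs are promoted to the
# default floating dtype via _util.cast_int_to_float before calling into torch,
# mirroring NumPy where e.g. ``sqrt`` of an integer array returns floats; unary
# funcs not in this list leave integer inputs as integers.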


def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting and broadcasting, and delegate
    to the PyTorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where=True,
        casting: Optional[CastingModes] = "same_kind",
        order="K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature=None,
        extobj=None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped
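
# Usage sketch (illustrative): ``sqrt(x, dtype="float64")`` casts ``x`` up front
# via typecast_tensor, while ``sqrt(x, out=o)`` runs the result through
# _ufunc_postprocess so it matches ``o``'s dtype and shape.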


#
# Attach ufuncs to this module, for further export to the public namespace in __init__.py
#
for name in _unary:
    ufunc = getattr(_unary_ufuncs_impl, name)
    vars()[name] = deco_unary_ufunc(ufunc)


__all__ = _binary + _unary  # noqa: PLE0605