# mypy: ignore-errors

""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
"""
from __future__ import annotations

import functools
import inspect
import operator
import typing

import torch

from . import _dtypes, _dtypes_impl, _util

ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")
# OutArray is used to annotate the out= array argument.
#
# This one is special in several respects:
# First, it needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So we never pass the out array to implementer functions and handle it in the
# `normalizer` below.
# Second, the out= argument can be either a keyword or a positional argument, and
# as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")
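
# A rough sketch of the `result is out` contract being preserved, written as a
# doctest-style comment (the `tnp` namespace is assumed to be this package's
# numpy-compatible facade; illustrative only, not executed at import time):
#
#   >>> a = tnp.asarray([1.0, 2.0])
#   >>> out = tnp.empty(2)
#   >>> result = tnp.add(a, a, out=out)
#   >>> result is out
#   True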
try:
    # NotImplementedType lives in the `types` module on Python >= 3.10;
    # fall back to a plain TypeVar on older versions. Either way, the
    # normalizers below dispatch on the annotation *string*.
    from types import NotImplementedType
except ImportError:
    NotImplementedType = typing.TypeVar("NotImplementedType")
def normalize_array_like(x, parm=None):
    from ._ndarray import asarray

    return asarray(x).tensor
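
# A hedged doctest-style sketch (assumes the default integer dtype maps to
# torch.int64, as in stock builds; illustrative only, not executed at import):
#
#   >>> normalize_array_like([1, 2, 3])
#   tensor([1, 2, 3])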
def normalize_array_like_or_scalar(x, parm=None):
    if _dtypes_impl.is_scalar_or_symbolic(x):
        return x
    return normalize_array_like(x, parm)


def normalize_optional_array_like_or_scalar(x, parm=None):
    if x is None:
        return None
    return normalize_array_like_or_scalar(x, parm)


def normalize_optional_array_like(x, parm=None):
    # This explicit normalizer is needed because otherwise normalize_array_like
    # does not run for a parameter annotated as Optional[ArrayLike]
    return None if x is None else normalize_array_like(x, parm)


def normalize_seq_array_like(x, parm=None):
    return tuple(normalize_array_like(value) for value in x)
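
# Sketch of the scalar-passthrough vs. conversion split (illustrative only):
#
#   >>> normalize_array_like_or_scalar(3.0)      # python scalars pass through
#   3.0
#   >>> normalize_seq_array_like([[1], [2]])     # each element becomes a Tensor
#   (tensor([1]), tensor([2]))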
def normalize_dtype(dtype, parm=None):
    # cf. _decorators.dtype_to_torch
    torch_dtype = None
    if dtype is not None:
        dtype = _dtypes.dtype(dtype)
        torch_dtype = dtype.torch_dtype
    return torch_dtype
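
# Hedged sketch (assumes _dtypes.dtype accepts numpy-style dtype specifiers
# such as strings; illustrative only):
#
#   >>> normalize_dtype("float64")
#   torch.float64
#   >>> normalize_dtype(None) is None
#   True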
def normalize_not_implemented(arg, parm):
    # Reject any non-default value; on success, implicitly return None
    # (the implementer function then receives None for this parameter).
    if arg != parm.default:
        raise NotImplementedError(f"'{parm.name}' parameter is not supported.")
def normalize_axis_like(arg, parm=None):
    from ._ndarray import ndarray

    # A zero-dim integer ndarray is converted to a python int via __index__.
    if isinstance(arg, ndarray):
        arg = operator.index(arg)
    return arg
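
# Sketch (assumes `tnp.asarray(1)` yields a 0-d integer ndarray; illustrative):
#
#   >>> normalize_axis_like(tnp.asarray(1))
#   1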
def normalize_ndarray(arg, parm=None):
    # check that arg is an ndarray, extract its tensor attribute
    if arg is None:
        return arg

    from ._ndarray import ndarray

    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")
    return arg.tensor
def normalize_outarray(arg, parm=None):
    # like normalize_ndarray, but returns the ndarray itself, not its tensor
    if arg is None:
        return arg

    from ._ndarray import ndarray

    # Dynamo can pass torch tensors as out arguments;
    # wrap them in an ndarray before processing
    if isinstance(arg, torch.Tensor):
        arg = ndarray(arg)
    if not isinstance(arg, ndarray):
        raise TypeError(f"'{parm.name}' must be an array")
    return arg
def normalize_casting(arg, parm=None):
    if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
        raise ValueError(
            f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
        )
    return arg
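
# Sketch: valid casting modes pass through unchanged; anything else raises.
#
#   >>> normalize_casting("same_kind")
#   'same_kind'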
# NB: the keys are annotation *strings*: with `from __future__ import annotations`
# in effect, inspect.signature reports annotations unevaluated, as strings.
normalizers = {
    "ArrayLike": normalize_array_like,
    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
    "Optional[ArrayLike]": normalize_optional_array_like,
    "Sequence[ArrayLike]": normalize_seq_array_like,
    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
    "Optional[NDArray]": normalize_ndarray,
    "Optional[OutArray]": normalize_outarray,
    "NDArray": normalize_ndarray,
    "Optional[DTypeLike]": normalize_dtype,
    "AxisLike": normalize_axis_like,
    "NotImplementedType": normalize_not_implemented,
    "Optional[CastingModes]": normalize_casting,
}
def maybe_normalize(arg, parm):
    """Normalize arg if a normalizer is registered."""
    normalizer = normalizers.get(parm.annotation, None)
    return normalizer(arg, parm) if normalizer else arg
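
# Hedged sketch of the string-keyed dispatch (constructing a Parameter by hand
# purely for illustration):
#
#   >>> parm = inspect.Parameter(
#   ...     "x", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation="ArrayLike"
#   ... )
#   >>> maybe_normalize([1, 2], parm)
#   tensor([1, 2])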
# ### Return value helpers ###


def maybe_copy_to(out, result, promote_scalar_result=False):
    # NB: here out is either an ndarray or None
    if out is None:
        return result
    elif isinstance(result, torch.Tensor):
        if result.shape != out.shape:
            # A 1-element result can be squeezed into a 0-dim out array,
            # but only when promote_scalar_result is set.
            can_fit = result.numel() == 1 and out.ndim == 0
            if promote_scalar_result and can_fit:
                result = result.squeeze()
            else:
                raise ValueError(
                    f"Bad size of the out array: out.shape = {out.shape}"
                    f" while result.shape = {result.shape}."
                )
        out.tensor.copy_(result)
        return out
    elif isinstance(result, (tuple, list)):
        return type(result)(
            maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
        )
    else:
        raise AssertionError  # We should never hit this path
def wrap_tensors(result):
    from ._ndarray import ndarray

    if isinstance(result, torch.Tensor):
        return ndarray(result)
    elif isinstance(result, (tuple, list)):
        result = type(result)(wrap_tensors(x) for x in result)
    return result
def array_or_scalar(values, py_type=float, return_scalar=False):
    if return_scalar:
        return py_type(values.item())
    else:
        from ._ndarray import ndarray

        return ndarray(values)
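
# Sketch (illustrative): a 0-d tensor either becomes a python scalar or is
# wrapped into an ndarray.
#
#   >>> array_or_scalar(torch.tensor(2.0), return_scalar=True)
#   2.0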
# ### The main decorator to normalize arguments / postprocess the output ###


def normalizer(_func=None, *, promote_scalar_result=False):
    def normalizer_inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwds):
            sig = inspect.signature(func)
            params = sig.parameters
            first_param = next(iter(params.values()))

            # NumPy's API does not have positional args before variadic positional args
            if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
                args = [maybe_normalize(arg, first_param) for arg in args]
            else:
                # NB: extra unknown arguments: pass through, will raise in func(*args) below
                args = tuple(
                    maybe_normalize(arg, parm)
                    for arg, parm in zip(args, params.values())
                ) + args[len(params):]

            kwds = {
                name: maybe_normalize(arg, params[name]) if name in params else arg
                for name, arg in kwds.items()
            }

            result = func(*args, **kwds)

            # keepdims
            bound_args = None
            if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
                # keepdims can be in any position, so we need sig.bind
                bound_args = sig.bind(*args, **kwds).arguments
                if bound_args.get("keepdims", False):
                    # In this case the first arg is the initial tensor and
                    # the second arg is (optionally) the axis
                    tensor = args[0]
                    axis = bound_args.get("axis")
                    result = _util.apply_keepdims(result, axis, tensor.ndim)

            # out
            if "out" in params:
                # out can be in any position, so we need sig.bind
                if bound_args is None:
                    bound_args = sig.bind(*args, **kwds).arguments
                out = bound_args.get("out")
                result = maybe_copy_to(out, result, promote_scalar_result)

            result = wrap_tensors(result)
            return result

        return wrapped

    if _func is None:
        return normalizer_inner
    else:
        return normalizer_inner(_func)
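
# A hedged end-to-end sketch of how an implementer function is wrapped. The
# function below is illustrative (not part of this module); it assumes the
# annotations are spelled exactly like the string keys in `normalizers`:
#
#   from typing import Optional
#
#   @normalizer
#   def add(x1: ArrayLike, x2: ArrayLike, out: Optional[OutArray] = None):
#       # By the time the body runs, x1 and x2 are torch.Tensors; the wrapper
#       # copies the result into `out` (if given) and wraps it into an ndarray.
#       return x1 + x2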