_normalizations.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # mypy: ignore-errors
  2. """ "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
  3. """
  4. from __future__ import annotations
  5. import functools
  6. import inspect
  7. import operator
  8. import typing
  9. import torch
  10. from . import _dtypes, _dtypes_impl, _util
# Annotation markers used by the implementation modules.
#
# ``maybe_normalize`` dispatches on ``parm.annotation`` by looking it up in the
# ``normalizers`` table, whose keys are annotation *strings* ("ArrayLike",
# "Optional[ArrayLike]", ...). A key only matches an annotation stored as that
# exact string (as under ``from __future__ import annotations``), so the
# objects bound below serve name resolution; they are not what dispatch keys on.
ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]
DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
# A ``keepdims`` parameter annotated as "KeepDims" enables the keepdims
# post-processing performed by ``normalizer``.
KeepDims = typing.TypeVar("KeepDims")
# OutArray is to annotate the out= array argument.
#
# This one is special in several respects:
# First, it needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So we never pass the out array to implementer functions and handle it in the
# `normalizer` below.
# Second, the out= argument can be either a keyword or a positional argument,
# and as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")
# Use a placeholder TypeVar when ``typing`` does not export NotImplementedType.
# NOTE(review): the real type lives in the ``types`` module (Python 3.10+), so
# this fallback is presumably always taken — confirm.
try:
    from typing import NotImplementedType
except ImportError:
    NotImplementedType = typing.TypeVar("NotImplementedType")
  35. def normalize_array_like(x, parm=None):
  36. from ._ndarray import asarray
  37. return asarray(x).tensor
  38. def normalize_array_like_or_scalar(x, parm=None):
  39. if _dtypes_impl.is_scalar_or_symbolic(x):
  40. return x
  41. return normalize_array_like(x, parm)
  42. def normalize_optional_array_like_or_scalar(x, parm=None):
  43. if x is None:
  44. return None
  45. return normalize_array_like_or_scalar(x, parm)
  46. def normalize_optional_array_like(x, parm=None):
  47. # This explicit normalizer is needed because otherwise normalize_array_like
  48. # does not run for a parameter annotated as Optional[ArrayLike]
  49. return None if x is None else normalize_array_like(x, parm)
  50. def normalize_seq_array_like(x, parm=None):
  51. return tuple(normalize_array_like(value) for value in x)
  52. def normalize_dtype(dtype, parm=None):
  53. # cf _decorators.dtype_to_torch
  54. torch_dtype = None
  55. if dtype is not None:
  56. dtype = _dtypes.dtype(dtype)
  57. torch_dtype = dtype.torch_dtype
  58. return torch_dtype
  59. def normalize_not_implemented(arg, parm):
  60. if arg != parm.default:
  61. raise NotImplementedError(f"'{parm.name}' parameter is not supported.")
  62. def normalize_axis_like(arg, parm=None):
  63. from ._ndarray import ndarray
  64. if isinstance(arg, ndarray):
  65. arg = operator.index(arg)
  66. return arg
  67. def normalize_ndarray(arg, parm=None):
  68. # check the arg is an ndarray, extract its tensor attribute
  69. if arg is None:
  70. return arg
  71. from ._ndarray import ndarray
  72. if not isinstance(arg, ndarray):
  73. raise TypeError(f"'{parm.name}' must be an array")
  74. return arg.tensor
  75. def normalize_outarray(arg, parm=None):
  76. # almost normalize_ndarray, only return the array, not its tensor
  77. if arg is None:
  78. return arg
  79. from ._ndarray import ndarray
  80. # Dynamo can pass torch tensors as out arguments,
  81. # wrap it in an ndarray before processing
  82. if isinstance(arg, torch.Tensor):
  83. arg = ndarray(arg)
  84. if not isinstance(arg, ndarray):
  85. raise TypeError(f"'{parm.name}' must be an array")
  86. return arg
  87. def normalize_casting(arg, parm=None):
  88. if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
  89. raise ValueError(
  90. f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
  91. )
  92. return arg
# Registry of argument normalizers, keyed by annotation text.
#
# ``maybe_normalize`` looks up ``parm.annotation`` here and, on a hit, calls
# the registered function as ``fn(arg, parm)``. Arguments whose annotation is
# not a key are passed through untouched. Lookup is by equality, so a key only
# matches an annotation stored as that exact string (as under
# ``from __future__ import annotations``). That is why "Optional[ArrayLike]"
# needs its own entry rather than falling back to "ArrayLike".
normalizers = {
    # array-likes -> torch tensors (the *OrScalar variants pass scalars through)
    "ArrayLike": normalize_array_like,
    "ArrayLikeOrScalar": normalize_array_like_or_scalar,
    "Optional[ArrayLike]": normalize_optional_array_like,
    "Sequence[ArrayLike]": normalize_seq_array_like,
    "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
    # must already be an ndarray; unwrapped to its tensor
    "Optional[NDArray]": normalize_ndarray,
    # out= stays an ndarray so that the same object can be returned
    "Optional[OutArray]": normalize_outarray,
    "NDArray": normalize_ndarray,
    "Optional[DTypeLike]": normalize_dtype,
    "AxisLike": normalize_axis_like,
    # any non-default value raises NotImplementedError
    "NotImplementedType": normalize_not_implemented,
    "Optional[CastingModes]": normalize_casting,
}
  107. def maybe_normalize(arg, parm):
  108. """Normalize arg if a normalizer is registered."""
  109. normalizer = normalizers.get(parm.annotation, None)
  110. return normalizer(arg, parm) if normalizer else arg
  111. # ### Return value helpers ###
  112. def maybe_copy_to(out, result, promote_scalar_result=False):
  113. # NB: here out is either an ndarray or None
  114. if out is None:
  115. return result
  116. elif isinstance(result, torch.Tensor):
  117. if result.shape != out.shape:
  118. can_fit = result.numel() == 1 and out.ndim == 0
  119. if promote_scalar_result and can_fit:
  120. result = result.squeeze()
  121. else:
  122. raise ValueError(
  123. f"Bad size of the out array: out.shape = {out.shape}"
  124. f" while result.shape = {result.shape}."
  125. )
  126. out.tensor.copy_(result)
  127. return out
  128. elif isinstance(result, (tuple, list)):
  129. return type(result)(
  130. maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
  131. )
  132. else:
  133. raise AssertionError # We should never hit this path
  134. def wrap_tensors(result):
  135. from ._ndarray import ndarray
  136. if isinstance(result, torch.Tensor):
  137. return ndarray(result)
  138. elif isinstance(result, (tuple, list)):
  139. result = type(result)(wrap_tensors(x) for x in result)
  140. return result
  141. def array_or_scalar(values, py_type=float, return_scalar=False):
  142. if return_scalar:
  143. return py_type(values.item())
  144. else:
  145. from ._ndarray import ndarray
  146. return ndarray(values)
  147. # ### The main decorator to normalize arguments / postprocess the output ###
  148. def normalizer(_func=None, *, promote_scalar_result=False):
  149. def normalizer_inner(func):
  150. @functools.wraps(func)
  151. def wrapped(*args, **kwds):
  152. sig = inspect.signature(func)
  153. params = sig.parameters
  154. first_param = next(iter(params.values()))
  155. # NumPy's API does not have positional args before variadic positional args
  156. if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
  157. args = [maybe_normalize(arg, first_param) for arg in args]
  158. else:
  159. # NB: extra unknown arguments: pass through, will raise in func(*args) below
  160. args = (
  161. tuple(
  162. maybe_normalize(arg, parm)
  163. for arg, parm in zip(args, params.values())
  164. )
  165. + args[len(params.values()) :]
  166. )
  167. kwds = {
  168. name: maybe_normalize(arg, params[name]) if name in params else arg
  169. for name, arg in kwds.items()
  170. }
  171. result = func(*args, **kwds)
  172. # keepdims
  173. bound_args = None
  174. if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
  175. # keepdims can be in any position so we need sig.bind
  176. bound_args = sig.bind(*args, **kwds).arguments
  177. if bound_args.get("keepdims", False):
  178. # In this case the first arg is the initial tensor and
  179. # the second arg is (optionally) the axis
  180. tensor = args[0]
  181. axis = bound_args.get("axis")
  182. result = _util.apply_keepdims(result, axis, tensor.ndim)
  183. # out
  184. if "out" in params:
  185. # out can be in any position so we need sig.bind
  186. if bound_args is None:
  187. bound_args = sig.bind(*args, **kwds).arguments
  188. out = bound_args.get("out")
  189. result = maybe_copy_to(out, result, promote_scalar_result)
  190. result = wrap_tensors(result)
  191. return result
  192. return wrapped
  193. if _func is None:
  194. return normalizer_inner
  195. else:
  196. return normalizer_inner(_func)