_util.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # mypy: ignore-errors
  2. """Assorted utilities, which do not need anything other than torch and stdlib.
  3. """
  4. import operator
  5. import torch
  6. from . import _dtypes_impl
  7. # https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
  8. def is_sequence(seq):
  9. if isinstance(seq, str):
  10. return False
  11. try:
  12. len(seq)
  13. except Exception:
  14. return False
  15. return True
  16. class AxisError(ValueError, IndexError):
  17. pass
  18. class UFuncTypeError(TypeError, RuntimeError):
  19. pass
  20. def cast_if_needed(tensor, dtype):
  21. # NB: no casting if dtype=None
  22. if dtype is not None and tensor.dtype != dtype:
  23. tensor = tensor.to(dtype)
  24. return tensor
  25. def cast_int_to_float(x):
  26. # cast integers and bools to the default float dtype
  27. if _dtypes_impl._category(x.dtype) < 2:
  28. x = x.to(_dtypes_impl.default_dtypes().float_dtype)
  29. return x
  30. # a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
  31. def normalize_axis_index(ax, ndim, argname=None):
  32. if not (-ndim <= ax < ndim):
  33. raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}")
  34. if ax < 0:
  35. ax += ndim
  36. return ax
  37. # from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
  38. def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
  39. """
  40. Normalizes an axis argument into a tuple of non-negative integer axes.
  41. This handles shorthands such as ``1`` and converts them to ``(1,)``,
  42. as well as performing the handling of negative indices covered by
  43. `normalize_axis_index`.
  44. By default, this forbids axes from being specified multiple times.
  45. Used internally by multi-axis-checking logic.
  46. Parameters
  47. ----------
  48. axis : int, iterable of int
  49. The un-normalized index or indices of the axis.
  50. ndim : int
  51. The number of dimensions of the array that `axis` should be normalized
  52. against.
  53. argname : str, optional
  54. A prefix to put before the error message, typically the name of the
  55. argument.
  56. allow_duplicate : bool, optional
  57. If False, the default, disallow an axis from being specified twice.
  58. Returns
  59. -------
  60. normalized_axes : tuple of int
  61. The normalized axis index, such that `0 <= normalized_axis < ndim`
  62. """
  63. # Optimization to speed-up the most common cases.
  64. if type(axis) not in (tuple, list):
  65. try:
  66. axis = [operator.index(axis)]
  67. except TypeError:
  68. pass
  69. # Going via an iterator directly is slower than via list comprehension.
  70. axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
  71. if not allow_duplicate and len(set(axis)) != len(axis):
  72. if argname:
  73. raise ValueError(f"repeated axis in `{argname}` argument")
  74. else:
  75. raise ValueError("repeated axis")
  76. return axis
  77. def allow_only_single_axis(axis):
  78. if axis is None:
  79. return axis
  80. if len(axis) != 1:
  81. raise NotImplementedError("does not handle tuple axis")
  82. return axis[0]
  83. def expand_shape(arr_shape, axis):
  84. # taken from numpy 1.23.x, expand_dims function
  85. if type(axis) not in (list, tuple):
  86. axis = (axis,)
  87. out_ndim = len(axis) + len(arr_shape)
  88. axis = normalize_axis_tuple(axis, out_ndim)
  89. shape_it = iter(arr_shape)
  90. shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
  91. return shape
  92. def apply_keepdims(tensor, axis, ndim):
  93. if axis is None:
  94. # tensor was a scalar
  95. shape = (1,) * ndim
  96. tensor = tensor.expand(shape).contiguous()
  97. else:
  98. shape = expand_shape(tensor.shape, axis)
  99. tensor = tensor.reshape(shape)
  100. return tensor
  101. def axis_none_flatten(*tensors, axis=None):
  102. """Flatten the arrays if axis is None."""
  103. if axis is None:
  104. tensors = tuple(ar.flatten() for ar in tensors)
  105. return tensors, 0
  106. else:
  107. return tensors, axis
  108. def typecast_tensor(t, target_dtype, casting):
  109. """Dtype-cast tensor to target_dtype.
  110. Parameters
  111. ----------
  112. t : torch.Tensor
  113. The tensor to cast
  114. target_dtype : torch dtype object
  115. The array dtype to cast all tensors to
  116. casting : str
  117. The casting mode, see `np.can_cast`
  118. Returns
  119. -------
  120. `torch.Tensor` of the `target_dtype` dtype
  121. Raises
  122. ------
  123. ValueError
  124. if the argument cannot be cast according to the `casting` rule
  125. """
  126. can_cast = _dtypes_impl.can_cast_impl
  127. if not can_cast(t.dtype, target_dtype, casting=casting):
  128. raise TypeError(
  129. f"Cannot cast array data from {t.dtype} to"
  130. f" {target_dtype} according to the rule '{casting}'"
  131. )
  132. return cast_if_needed(t, target_dtype)
  133. def typecast_tensors(tensors, target_dtype, casting):
  134. return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)
  135. def _try_convert_to_tensor(obj):
  136. try:
  137. tensor = torch.as_tensor(obj)
  138. except Exception as e:
  139. mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
  140. raise NotImplementedError(mesg) # noqa: B904
  141. return tensor
  142. def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
  143. """The core logic of the array(...) function.
  144. Parameters
  145. ----------
  146. obj : tensor_like
  147. The thing to coerce
  148. dtype : torch.dtype object or None
  149. Coerce to this torch dtype
  150. copy : bool
  151. Copy or not
  152. ndmin : int
  153. The results as least this many dimensions
  154. is_weak : bool
  155. Whether obj is a weakly typed python scalar.
  156. Returns
  157. -------
  158. tensor : torch.Tensor
  159. a tensor object with requested dtype, ndim and copy semantics.
  160. Notes
  161. -----
  162. This is almost a "tensor_like" coersion function. Does not handle wrapper
  163. ndarrays (those should be handled in the ndarray-aware layer prior to
  164. invoking this function).
  165. """
  166. if isinstance(obj, torch.Tensor):
  167. tensor = obj
  168. else:
  169. # tensor.dtype is the pytorch default, typically float32. If obj's elements
  170. # are not exactly representable in float32, we've lost precision:
  171. # >>> torch.as_tensor(1e12).item() - 1e12
  172. # -4096.0
  173. default_dtype = torch.get_default_dtype()
  174. torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
  175. try:
  176. tensor = _try_convert_to_tensor(obj)
  177. finally:
  178. torch.set_default_dtype(default_dtype)
  179. # type cast if requested
  180. tensor = cast_if_needed(tensor, dtype)
  181. # adjust ndim if needed
  182. ndim_extra = ndmin - tensor.ndim
  183. if ndim_extra > 0:
  184. tensor = tensor.view((1,) * ndim_extra + tensor.shape)
  185. # copy if requested
  186. if copy:
  187. tensor = tensor.clone()
  188. return tensor
  189. def ndarrays_to_tensors(*inputs):
  190. """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
  191. from ._ndarray import ndarray
  192. if len(inputs) == 0:
  193. return ValueError()
  194. elif len(inputs) == 1:
  195. input_ = inputs[0]
  196. if isinstance(input_, ndarray):
  197. return input_.tensor
  198. elif isinstance(input_, tuple):
  199. result = []
  200. for sub_input in input_:
  201. sub_result = ndarrays_to_tensors(sub_input)
  202. result.append(sub_result)
  203. return tuple(result)
  204. else:
  205. return input_
  206. else:
  207. assert isinstance(inputs, tuple) # sanity check
  208. return ndarrays_to_tensors(inputs)