| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- # mypy: ignore-errors
- """Assorted utilities, which do not need anything other than torch and stdlib.
- """
- import operator
- import torch
- from . import _dtypes_impl
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    """Return True if `seq` supports ``len()`` and is not a ``str``.

    Strings are explicitly excluded; any other object counts as a sequence
    exactly when calling ``len()`` on it does not raise.
    """
    if not isinstance(seq, str):
        try:
            len(seq)
        except Exception:
            pass
        else:
            return True
    return False
class AxisError(ValueError, IndexError):
    """Raised when an axis index is out of bounds for an array's dimension.

    Inherits from both ``ValueError`` and ``IndexError`` so that handlers
    catching either exception type also catch it.
    """
class UFuncTypeError(TypeError, RuntimeError):
    """Type error whose name mirrors NumPy's ``UFuncTypeError``.

    Inherits from both ``TypeError`` and ``RuntimeError`` so that handlers
    catching either exception type also catch it.
    """
def cast_if_needed(tensor, dtype):
    """Return `tensor` converted to `dtype`, or `tensor` itself if no cast is needed.

    ``dtype=None`` means "keep the current dtype". When the dtype already
    matches, the very same tensor object is returned (no copy is made).
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor
def cast_int_to_float(x):
    """Cast integer and boolean tensors to the default float dtype.

    The dtype category is looked up with ``_dtypes_impl._category``; category
    codes below 2 are treated as integers/bools (assumed mapping, defined in
    ``_dtypes_impl`` -- confirm there). Such tensors are converted to
    ``_dtypes_impl.default_dtypes().float_dtype``; all others are returned
    unchanged.
    """
    if _dtypes_impl._category(x.dtype) >= 2:
        return x
    float_dtype = _dtypes_impl.default_dtypes().float_dtype
    return x.to(float_dtype)
# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    """Normalize a possibly negative axis index into the range ``[0, ndim)``.

    Parameters
    ----------
    ax : int
        The axis index to normalize; negative values count from the end.
    ndim : int
        The number of dimensions of the array the axis refers to.
    argname : str, optional
        If given, prefixed to the error message (typically the name of the
        argument being validated), like NumPy's ``msg_prefix``.

    Returns
    -------
    int
        The equivalent non-negative axis index.

    Raises
    ------
    AxisError
        If `ax` is not in ``[-ndim, ndim)``.
    """
    if not (-ndim <= ax < ndim):
        msg = f"axis {ax} is out of bounds for array of dimension {ndim}"
        # `argname` used to be accepted but ignored; NumPy reports it as a
        # prefix, e.g. "source: axis 5 is out of bounds ...".
        if argname:
            msg = f"{argname}: {msg}"
        raise AxisError(msg)
    if ax < 0:
        ax += ndim
    return ax
# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as performing the handling of negative indices covered by
    `normalize_axis_index`.

    By default, this forbids axes from being specified multiple times.

    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis index, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    ValueError
        If an axis is repeated and `allow_duplicate` is False.
    """
    # Fast path for the most common case: a single integer-like axis.
    if type(axis) in (tuple, list):
        candidates = axis
    else:
        try:
            candidates = [operator.index(axis)]
        except TypeError:
            # Not integer-like: treat it as some other iterable of axes.
            candidates = axis

    # A list comprehension is faster than feeding a generator to tuple().
    normalized = tuple(
        [normalize_axis_index(ax, ndim, argname) for ax in candidates]
    )

    if allow_duplicate or len(set(normalized)) == len(normalized):
        return normalized
    if argname:
        raise ValueError(f"repeated axis in `{argname}` argument")
    raise ValueError("repeated axis")
def allow_only_single_axis(axis):
    """Unwrap a one-element axis sequence; pass ``None`` through unchanged.

    Raises
    ------
    NotImplementedError
        If `axis` holds anything other than exactly one element.
    """
    if axis is None:
        return None
    if len(axis) == 1:
        return axis[0]
    raise NotImplementedError("does not handle tuple axis")
def expand_shape(arr_shape, axis):
    """Return `arr_shape` as a list with size-1 dimensions inserted at `axis`.

    `axis` (an int, or a list/tuple of ints) indexes into the *output*
    shape, whose length is ``len(arr_shape)`` plus the number of new axes.
    The remaining output positions are filled with `arr_shape` in order.
    """
    # taken from numpy 1.23.x, expand_dims function
    axes = axis if type(axis) in (list, tuple) else (axis,)
    out_ndim = len(axes) + len(arr_shape)
    new_axes = normalize_axis_tuple(axes, out_ndim)
    remaining = iter(arr_shape)
    return [1 if dim in new_axes else next(remaining) for dim in range(out_ndim)]
def apply_keepdims(tensor, axis, ndim):
    """Insert size-1 dimensions into `tensor` at `axis` (``keepdims`` handling).

    With ``axis=None``, `tensor` is expected to be a scalar: it is expanded to
    shape ``(1,) * ndim`` and made contiguous. Otherwise `tensor` is reshaped
    to ``expand_shape(tensor.shape, axis)`` and `ndim` is not used.
    """
    if axis is not None:
        return tensor.reshape(expand_shape(tensor.shape, axis))
    # tensor was a scalar: broadcast it to an all-ones shape of rank `ndim`
    # and materialize the broadcast with contiguous().
    return tensor.expand((1,) * ndim).contiguous()
def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns a ``(tensors, axis)`` pair. When `axis` is None, every tensor is
    flattened to 1-D and the returned axis is 0. Otherwise the input tuple
    and `axis` are returned untouched.
    """
    if axis is not None:
        return tensors, axis
    flat = tuple(t.flatten() for t in tensors)
    return flat, 0
def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The dtype to cast the tensor to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype. If `t` already has that
    dtype, `t` itself is returned (via `cast_if_needed`).

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule
    """
    can_cast = _dtypes_impl.can_cast_impl
    if not can_cast(t.dtype, target_dtype, casting=casting):
        # TypeError (not ValueError) is what is actually raised here.
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
def typecast_tensors(tensors, target_dtype, casting):
    """Cast every tensor in `tensors` with `typecast_tensor`; return a tuple.

    Each element is checked against the `casting` rule, so the first tensor
    that cannot be cast raises, exactly as `typecast_tensor` would.
    """
    converted = [typecast_tensor(t, target_dtype, casting) for t in tensors]
    return tuple(converted)
- def _try_convert_to_tensor(obj):
- try:
- tensor = torch.as_tensor(obj)
- except Exception as e:
- mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
- raise NotImplementedError(mesg) # noqa: B904
- return tensor
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype; None keeps the inferred dtype.
    copy : bool
        If True, always return a fresh clone. If False, the result may share
        memory with `obj` (a tensor `obj` needing no cast or reshape is
        returned as-is).
    ndmin : int
        The result has at least this many dimensions; missing dimensions are
        prepended as size-1 axes.

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Raises
    ------
    NotImplementedError
        If a non-tensor `obj` cannot be converted by ``torch.as_tensor``.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        # So the global default dtype is switched temporarily for the conversion
        # and restored in `finally`, even if the conversion raises.
        # NOTE(review): this mutates process-global torch state; confirm no
        # concurrent callers rely on the default dtype during this window.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed: prepend size-1 dimensions as a view
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested (unconditional, even if the cast above already copied)
    if copy:
        tensor = tensor.clone()

    return tensor
def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    With a single argument, the converted object is returned: an ``ndarray``
    is unwrapped to its ``.tensor``, a tuple is converted element-wise
    (recursively), and anything else is returned unchanged. With several
    arguments, a tuple of the converted arguments is returned.

    Raises
    ------
    ValueError
        If called without any arguments.
    """
    if not inputs:
        # Raise (rather than return) the error, so an empty call cannot
        # silently hand an exception object back to the caller.
        raise ValueError("ndarrays_to_tensors() requires at least one argument")

    # Function-scope import, presumably to avoid a circular import with
    # ._ndarray -- confirm before hoisting it.
    from ._ndarray import ndarray

    if len(inputs) > 1:
        # Several arguments: convert them as one tuple, element-wise.
        return ndarrays_to_tensors(inputs)

    (input_,) = inputs
    if isinstance(input_, ndarray):
        return input_.tensor
    if isinstance(input_, tuple):
        return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
    return input_
|