| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- # mypy: ignore-errors
- from __future__ import annotations
- import functools
- import math
- from typing import Sequence
- import torch
- from . import _dtypes_impl, _util
- from ._normalizations import ArrayLike, KeepDims, normalizer
class LinAlgError(Exception):
    """Raised when a linear-algebra routine fails (mirrors ``numpy.linalg.LinAlgError``)."""
- def _atleast_float_1(a):
- if not (a.dtype.is_floating_point or a.dtype.is_complex):
- a = a.to(_dtypes_impl.default_dtypes().float_dtype)
- return a
def _atleast_float_2(a, b):
    """Cast `a` and `b` to a common dtype, promoting to the default float when neither is inexact."""
    common = _dtypes_impl.result_type_impl(a, b)
    if not (common.is_floating_point or common.is_complex):
        common = _dtypes_impl.default_dtypes().float_dtype
    return _util.cast_if_needed(a, common), _util.cast_if_needed(b, common)
def linalg_errors(func):
    """Decorator translating torch's internal linalg failures into ``LinAlgError``."""

    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except torch._C._LinAlgError as exc:
            # Re-raise under the NumPy-compatible exception type.
            raise LinAlgError(*exc.args)  # noqa: B904

    return inner
# ### Matrix and vector products ###


@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise (stacked) square matrices to the integer power `n`."""
    mat = _atleast_float_1(a)
    return torch.linalg.matrix_power(mat, n)
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Chain-multiply a sequence of matrices.

    NOTE(review): ``out=`` is accepted for NumPy API compatibility but is
    not forwarded to torch.
    """
    return torch.linalg.multi_dot(inputs)
- # ### Solving equations and inverting matrices ###
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system ``a @ x = b``."""
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(lhs, rhs)
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Least-squares solution to ``a @ x = b``."""
    a, b = _atleast_float_2(a, b)
    # NumPy uses the gelsd LAPACK driver:
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    # CUDA only provides `gels`, so fall back to it there.
    on_cuda = a.is_cuda or b.is_cuda
    driver = "gels" if on_cuda else "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Invert (stacked) square matrices."""
    return torch.linalg.inv(_atleast_float_1(a))
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Moore-Penrose pseudo-inverse; NumPy's `rcond` maps onto torch's `rtol`."""
    mat = _atleast_float_1(a)
    return torch.linalg.pinv(mat, rtol=rcond, hermitian=hermitian)
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation ``a x = b`` for x; `axes` maps to torch's `dims`."""
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(lhs, rhs, dims=axes)
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """'Inverse' of a tensor with respect to tensordot over the last `ind` axes."""
    return torch.linalg.tensorinv(_atleast_float_1(a), ind=ind)
# ### Norms and other numbers ###


@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Determinant of (stacked) square matrices."""
    return torch.linalg.det(_atleast_float_1(a))
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Sign and natural log of the absolute determinant."""
    return torch.linalg.slogdet(_atleast_float_1(a))
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Condition number of a matrix, in the norm selected by `p`."""
    x = _atleast_float_1(x)

    # Reject empty matrices, matching NumPy:
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # NumPy converts nans to infs in a data-dependent way (depending on
    # whether the input itself contains nans); we do it unconditionally.
    # XXX: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    nan_mask = torch.isnan(result)
    return torch.where(nan_mask, float("inf"), result)
@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    """Rank of a matrix computed from its singular values."""
    a = _atleast_float_1(a)

    # 0-D / 1-D input: rank is 1 iff any entry is nonzero.
    if a.ndim < 2:
        return int((a != 0).any())

    if tol is not None:
        atol, rtol = tol, 0
    else:
        # Default tolerance mirrors NumPy:
        # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
        atol = 0
        rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
    return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
    """Matrix or vector norm.

    NOTE(review): `keepdims` is not forwarded here — presumably the
    ``KeepDims`` annotation re-applies it inside the normalizer; confirm.
    """
    return torch.linalg.norm(_atleast_float_1(x), ord=ord, dim=axis)
# ### Decompositions ###


@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Cholesky factorization of (stacked) positive-definite matrices."""
    return torch.linalg.cholesky(_atleast_float_1(a))
@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    """QR decomposition; in mode="r" only the R factor is returned, as in NumPy."""
    decomposition = torch.linalg.qr(_atleast_float_1(a), mode=mode)
    if mode != "r":
        return decomposition
    # match NumPy: a bare tensor, not a named tuple
    return decomposition.R
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Singular value decomposition.

    `hermitian=` is accepted for NumPy compatibility but ignored — torch
    has no equivalent.
    """
    a = _atleast_float_1(a)
    if compute_uv:
        return torch.linalg.svd(a, full_matrices=full_matrices)
    # Singular values only.
    return torch.linalg.svdvals(a)
# ### Eigenvalues and eigenvectors ###


@normalizer
@linalg_errors
def eig(a: ArrayLike):
    """Eigenvalues and right eigenvectors of (stacked) square matrices."""
    a = _atleast_float_1(a)
    values, vectors = torch.linalg.eig(a)
    # A real input whose eigenvalues all come out real gets real-valued
    # results, matching NumPy.
    if not a.is_complex() and values.is_complex() and (values.imag == 0).all():
        values = values.real
        vectors = vectors.real
    return values, vectors
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Eigen-decomposition of Hermitian/symmetric matrices (reads the `UPLO` triangle)."""
    return torch.linalg.eigh(_atleast_float_1(a), UPLO=UPLO)
@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    """Eigenvalues of (stacked) square matrices, without eigenvectors."""
    a = _atleast_float_1(a)
    values = torch.linalg.eigvals(a)
    # Real input with all-real eigenvalues yields a real result, as NumPy does.
    if not a.is_complex() and values.is_complex() and (values.imag == 0).all():
        values = values.real
    return values
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Eigenvalues of Hermitian/symmetric matrices (reads the `UPLO` triangle)."""
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)
|