linalg.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import functools
  4. import math
  5. from typing import Sequence
  6. import torch
  7. from . import _dtypes_impl, _util
  8. from ._normalizations import ArrayLike, KeepDims, normalizer
  9. class LinAlgError(Exception):
  10. pass
  11. def _atleast_float_1(a):
  12. if not (a.dtype.is_floating_point or a.dtype.is_complex):
  13. a = a.to(_dtypes_impl.default_dtypes().float_dtype)
  14. return a
  15. def _atleast_float_2(a, b):
  16. dtyp = _dtypes_impl.result_type_impl(a, b)
  17. if not (dtyp.is_floating_point or dtyp.is_complex):
  18. dtyp = _dtypes_impl.default_dtypes().float_dtype
  19. a = _util.cast_if_needed(a, dtyp)
  20. b = _util.cast_if_needed(b, dtyp)
  21. return a, b
  22. def linalg_errors(func):
  23. @functools.wraps(func)
  24. def wrapped(*args, **kwds):
  25. try:
  26. return func(*args, **kwds)
  27. except torch._C._LinAlgError as e:
  28. raise LinAlgError(*e.args) # noqa: B904
  29. return wrapped
  30. # ### Matrix and vector products ###
  31. @normalizer
  32. @linalg_errors
  33. def matrix_power(a: ArrayLike, n):
  34. a = _atleast_float_1(a)
  35. return torch.linalg.matrix_power(a, n)
  36. @normalizer
  37. @linalg_errors
  38. def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
  39. return torch.linalg.multi_dot(inputs)
  40. # ### Solving equations and inverting matrices ###
  41. @normalizer
  42. @linalg_errors
  43. def solve(a: ArrayLike, b: ArrayLike):
  44. a, b = _atleast_float_2(a, b)
  45. return torch.linalg.solve(a, b)
  46. @normalizer
  47. @linalg_errors
  48. def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
  49. a, b = _atleast_float_2(a, b)
  50. # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
  51. # on CUDA, only `gels` is available though, so use it instead
  52. driver = "gels" if a.is_cuda or b.is_cuda else "gelsd"
  53. return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
  54. @normalizer
  55. @linalg_errors
  56. def inv(a: ArrayLike):
  57. a = _atleast_float_1(a)
  58. result = torch.linalg.inv(a)
  59. return result
  60. @normalizer
  61. @linalg_errors
  62. def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
  63. a = _atleast_float_1(a)
  64. return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian)
  65. @normalizer
  66. @linalg_errors
  67. def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
  68. a, b = _atleast_float_2(a, b)
  69. return torch.linalg.tensorsolve(a, b, dims=axes)
  70. @normalizer
  71. @linalg_errors
  72. def tensorinv(a: ArrayLike, ind=2):
  73. a = _atleast_float_1(a)
  74. return torch.linalg.tensorinv(a, ind=ind)
  75. # ### Norms and other numbers ###
  76. @normalizer
  77. @linalg_errors
  78. def det(a: ArrayLike):
  79. a = _atleast_float_1(a)
  80. return torch.linalg.det(a)
  81. @normalizer
  82. @linalg_errors
  83. def slogdet(a: ArrayLike):
  84. a = _atleast_float_1(a)
  85. return torch.linalg.slogdet(a)
  86. @normalizer
  87. @linalg_errors
  88. def cond(x: ArrayLike, p=None):
  89. x = _atleast_float_1(x)
  90. # check if empty
  91. # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
  92. if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
  93. raise LinAlgError("cond is not defined on empty arrays")
  94. result = torch.linalg.cond(x, p=p)
  95. # Convert nans to infs (numpy does it in a data-dependent way, depending on
  96. # whether the input array has nans or not)
  97. # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
  98. return torch.where(torch.isnan(result), float("inf"), result)
  99. @normalizer
  100. @linalg_errors
  101. def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
  102. a = _atleast_float_1(a)
  103. if a.ndim < 2:
  104. return int((a != 0).any())
  105. if tol is None:
  106. # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
  107. atol = 0
  108. rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
  109. else:
  110. atol, rtol = tol, 0
  111. return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
  112. @normalizer
  113. @linalg_errors
  114. def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
  115. x = _atleast_float_1(x)
  116. return torch.linalg.norm(x, ord=ord, dim=axis)
  117. # ### Decompositions ###
  118. @normalizer
  119. @linalg_errors
  120. def cholesky(a: ArrayLike):
  121. a = _atleast_float_1(a)
  122. return torch.linalg.cholesky(a)
  123. @normalizer
  124. @linalg_errors
  125. def qr(a: ArrayLike, mode="reduced"):
  126. a = _atleast_float_1(a)
  127. result = torch.linalg.qr(a, mode=mode)
  128. if mode == "r":
  129. # match NumPy
  130. result = result.R
  131. return result
  132. @normalizer
  133. @linalg_errors
  134. def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
  135. a = _atleast_float_1(a)
  136. if not compute_uv:
  137. return torch.linalg.svdvals(a)
  138. # NB: ignore the hermitian= argument (no pytorch equivalent)
  139. result = torch.linalg.svd(a, full_matrices=full_matrices)
  140. return result
  141. # ### Eigenvalues and eigenvectors ###
  142. @normalizer
  143. @linalg_errors
  144. def eig(a: ArrayLike):
  145. a = _atleast_float_1(a)
  146. w, vt = torch.linalg.eig(a)
  147. if not a.is_complex() and w.is_complex() and (w.imag == 0).all():
  148. w = w.real
  149. vt = vt.real
  150. return w, vt
  151. @normalizer
  152. @linalg_errors
  153. def eigh(a: ArrayLike, UPLO="L"):
  154. a = _atleast_float_1(a)
  155. return torch.linalg.eigh(a, UPLO=UPLO)
  156. @normalizer
  157. @linalg_errors
  158. def eigvals(a: ArrayLike):
  159. a = _atleast_float_1(a)
  160. result = torch.linalg.eigvals(a)
  161. if not a.is_complex() and result.is_complex() and (result.imag == 0).all():
  162. result = result.real
  163. return result
  164. @normalizer
  165. @linalg_errors
  166. def eigvalsh(a: ArrayLike, UPLO="L"):
  167. a = _atleast_float_1(a)
  168. return torch.linalg.eigvalsh(a, UPLO=UPLO)