_omp.py 37 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108
  1. """Orthogonal matching pursuit algorithms
  2. """
  3. # Author: Vlad Niculae
  4. #
  5. # License: BSD 3 clause
  6. import warnings
  7. from math import sqrt
  8. from numbers import Integral, Real
  9. import numpy as np
  10. from scipy import linalg
  11. from scipy.linalg.lapack import get_lapack_funcs
  12. from ..base import MultiOutputMixin, RegressorMixin, _fit_context
  13. from ..model_selection import check_cv
  14. from ..utils import as_float_array, check_array
  15. from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
  16. from ..utils.parallel import Parallel, delayed
  17. from ._base import LinearModel, _deprecate_normalize, _pre_fit
  18. premature = (
  19. "Orthogonal matching pursuit ended prematurely due to linear"
  20. " dependence in the dictionary. The requested precision might"
  21. " not have been met."
  22. )
  23. def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True, return_path=False):
  24. """Orthogonal Matching Pursuit step using the Cholesky decomposition.
  25. Parameters
  26. ----------
  27. X : ndarray of shape (n_samples, n_features)
  28. Input dictionary. Columns are assumed to have unit norm.
  29. y : ndarray of shape (n_samples,)
  30. Input targets.
  31. n_nonzero_coefs : int
  32. Targeted number of non-zero elements.
  33. tol : float, default=None
  34. Targeted squared error, if not None overrides n_nonzero_coefs.
  35. copy_X : bool, default=True
  36. Whether the design matrix X must be copied by the algorithm. A false
  37. value is only helpful if X is already Fortran-ordered, otherwise a
  38. copy is made anyway.
  39. return_path : bool, default=False
  40. Whether to return every value of the nonzero coefficients along the
  41. forward path. Useful for cross-validation.
  42. Returns
  43. -------
  44. gamma : ndarray of shape (n_nonzero_coefs,)
  45. Non-zero elements of the solution.
  46. idx : ndarray of shape (n_nonzero_coefs,)
  47. Indices of the positions of the elements in gamma within the solution
  48. vector.
  49. coef : ndarray of shape (n_features, n_nonzero_coefs)
  50. The first k values of column k correspond to the coefficient value
  51. for the active features at that step. The lower left triangle contains
  52. garbage. Only returned if ``return_path=True``.
  53. n_active : int
  54. Number of active features at convergence.
  55. """
  56. if copy_X:
  57. X = X.copy("F")
  58. else: # even if we are allowed to overwrite, still copy it if bad order
  59. X = np.asfortranarray(X)
  60. min_float = np.finfo(X.dtype).eps
  61. nrm2, swap = linalg.get_blas_funcs(("nrm2", "swap"), (X,))
  62. (potrs,) = get_lapack_funcs(("potrs",), (X,))
  63. alpha = np.dot(X.T, y)
  64. residual = y
  65. gamma = np.empty(0)
  66. n_active = 0
  67. indices = np.arange(X.shape[1]) # keeping track of swapping
  68. max_features = X.shape[1] if tol is not None else n_nonzero_coefs
  69. L = np.empty((max_features, max_features), dtype=X.dtype)
  70. if return_path:
  71. coefs = np.empty_like(L)
  72. while True:
  73. lam = np.argmax(np.abs(np.dot(X.T, residual)))
  74. if lam < n_active or alpha[lam] ** 2 < min_float:
  75. # atom already selected or inner product too small
  76. warnings.warn(premature, RuntimeWarning, stacklevel=2)
  77. break
  78. if n_active > 0:
  79. # Updates the Cholesky decomposition of X' X
  80. L[n_active, :n_active] = np.dot(X[:, :n_active].T, X[:, lam])
  81. linalg.solve_triangular(
  82. L[:n_active, :n_active],
  83. L[n_active, :n_active],
  84. trans=0,
  85. lower=1,
  86. overwrite_b=True,
  87. check_finite=False,
  88. )
  89. v = nrm2(L[n_active, :n_active]) ** 2
  90. Lkk = linalg.norm(X[:, lam]) ** 2 - v
  91. if Lkk <= min_float: # selected atoms are dependent
  92. warnings.warn(premature, RuntimeWarning, stacklevel=2)
  93. break
  94. L[n_active, n_active] = sqrt(Lkk)
  95. else:
  96. L[0, 0] = linalg.norm(X[:, lam])
  97. X.T[n_active], X.T[lam] = swap(X.T[n_active], X.T[lam])
  98. alpha[n_active], alpha[lam] = alpha[lam], alpha[n_active]
  99. indices[n_active], indices[lam] = indices[lam], indices[n_active]
  100. n_active += 1
  101. # solves LL'x = X'y as a composition of two triangular systems
  102. gamma, _ = potrs(
  103. L[:n_active, :n_active], alpha[:n_active], lower=True, overwrite_b=False
  104. )
  105. if return_path:
  106. coefs[:n_active, n_active - 1] = gamma
  107. residual = y - np.dot(X[:, :n_active], gamma)
  108. if tol is not None and nrm2(residual) ** 2 <= tol:
  109. break
  110. elif n_active == max_features:
  111. break
  112. if return_path:
  113. return gamma, indices[:n_active], coefs[:, :n_active], n_active
  114. else:
  115. return gamma, indices[:n_active], n_active
  116. def _gram_omp(
  117. Gram,
  118. Xy,
  119. n_nonzero_coefs,
  120. tol_0=None,
  121. tol=None,
  122. copy_Gram=True,
  123. copy_Xy=True,
  124. return_path=False,
  125. ):
  126. """Orthogonal Matching Pursuit step on a precomputed Gram matrix.
  127. This function uses the Cholesky decomposition method.
  128. Parameters
  129. ----------
  130. Gram : ndarray of shape (n_features, n_features)
  131. Gram matrix of the input data matrix.
  132. Xy : ndarray of shape (n_features,)
  133. Input targets.
  134. n_nonzero_coefs : int
  135. Targeted number of non-zero elements.
  136. tol_0 : float, default=None
  137. Squared norm of y, required if tol is not None.
  138. tol : float, default=None
  139. Targeted squared error, if not None overrides n_nonzero_coefs.
  140. copy_Gram : bool, default=True
  141. Whether the gram matrix must be copied by the algorithm. A false
  142. value is only helpful if it is already Fortran-ordered, otherwise a
  143. copy is made anyway.
  144. copy_Xy : bool, default=True
  145. Whether the covariance vector Xy must be copied by the algorithm.
  146. If False, it may be overwritten.
  147. return_path : bool, default=False
  148. Whether to return every value of the nonzero coefficients along the
  149. forward path. Useful for cross-validation.
  150. Returns
  151. -------
  152. gamma : ndarray of shape (n_nonzero_coefs,)
  153. Non-zero elements of the solution.
  154. idx : ndarray of shape (n_nonzero_coefs,)
  155. Indices of the positions of the elements in gamma within the solution
  156. vector.
  157. coefs : ndarray of shape (n_features, n_nonzero_coefs)
  158. The first k values of column k correspond to the coefficient value
  159. for the active features at that step. The lower left triangle contains
  160. garbage. Only returned if ``return_path=True``.
  161. n_active : int
  162. Number of active features at convergence.
  163. """
  164. Gram = Gram.copy("F") if copy_Gram else np.asfortranarray(Gram)
  165. if copy_Xy or not Xy.flags.writeable:
  166. Xy = Xy.copy()
  167. min_float = np.finfo(Gram.dtype).eps
  168. nrm2, swap = linalg.get_blas_funcs(("nrm2", "swap"), (Gram,))
  169. (potrs,) = get_lapack_funcs(("potrs",), (Gram,))
  170. indices = np.arange(len(Gram)) # keeping track of swapping
  171. alpha = Xy
  172. tol_curr = tol_0
  173. delta = 0
  174. gamma = np.empty(0)
  175. n_active = 0
  176. max_features = len(Gram) if tol is not None else n_nonzero_coefs
  177. L = np.empty((max_features, max_features), dtype=Gram.dtype)
  178. L[0, 0] = 1.0
  179. if return_path:
  180. coefs = np.empty_like(L)
  181. while True:
  182. lam = np.argmax(np.abs(alpha))
  183. if lam < n_active or alpha[lam] ** 2 < min_float:
  184. # selected same atom twice, or inner product too small
  185. warnings.warn(premature, RuntimeWarning, stacklevel=3)
  186. break
  187. if n_active > 0:
  188. L[n_active, :n_active] = Gram[lam, :n_active]
  189. linalg.solve_triangular(
  190. L[:n_active, :n_active],
  191. L[n_active, :n_active],
  192. trans=0,
  193. lower=1,
  194. overwrite_b=True,
  195. check_finite=False,
  196. )
  197. v = nrm2(L[n_active, :n_active]) ** 2
  198. Lkk = Gram[lam, lam] - v
  199. if Lkk <= min_float: # selected atoms are dependent
  200. warnings.warn(premature, RuntimeWarning, stacklevel=3)
  201. break
  202. L[n_active, n_active] = sqrt(Lkk)
  203. else:
  204. L[0, 0] = sqrt(Gram[lam, lam])
  205. Gram[n_active], Gram[lam] = swap(Gram[n_active], Gram[lam])
  206. Gram.T[n_active], Gram.T[lam] = swap(Gram.T[n_active], Gram.T[lam])
  207. indices[n_active], indices[lam] = indices[lam], indices[n_active]
  208. Xy[n_active], Xy[lam] = Xy[lam], Xy[n_active]
  209. n_active += 1
  210. # solves LL'x = X'y as a composition of two triangular systems
  211. gamma, _ = potrs(
  212. L[:n_active, :n_active], Xy[:n_active], lower=True, overwrite_b=False
  213. )
  214. if return_path:
  215. coefs[:n_active, n_active - 1] = gamma
  216. beta = np.dot(Gram[:, :n_active], gamma)
  217. alpha = Xy - beta
  218. if tol is not None:
  219. tol_curr += delta
  220. delta = np.inner(gamma, beta[:n_active])
  221. tol_curr -= delta
  222. if abs(tol_curr) <= tol:
  223. break
  224. elif n_active == max_features:
  225. break
  226. if return_path:
  227. return gamma, indices[:n_active], coefs[:, :n_active], n_active
  228. else:
  229. return gamma, indices[:n_active], n_active
  230. @validate_params(
  231. {
  232. "X": ["array-like"],
  233. "y": [np.ndarray],
  234. "n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
  235. "tol": [Interval(Real, 0, None, closed="left"), None],
  236. "precompute": ["boolean", StrOptions({"auto"})],
  237. "copy_X": ["boolean"],
  238. "return_path": ["boolean"],
  239. "return_n_iter": ["boolean"],
  240. },
  241. prefer_skip_nested_validation=True,
  242. )
  243. def orthogonal_mp(
  244. X,
  245. y,
  246. *,
  247. n_nonzero_coefs=None,
  248. tol=None,
  249. precompute=False,
  250. copy_X=True,
  251. return_path=False,
  252. return_n_iter=False,
  253. ):
  254. r"""Orthogonal Matching Pursuit (OMP).
  255. Solves n_targets Orthogonal Matching Pursuit problems.
  256. An instance of the problem has the form:
  257. When parametrized by the number of non-zero coefficients using
  258. `n_nonzero_coefs`:
  259. argmin ||y - X\gamma||^2 subject to ||\gamma||_0 <= n_{nonzero coefs}
  260. When parametrized by error using the parameter `tol`:
  261. argmin ||\gamma||_0 subject to ||y - X\gamma||^2 <= tol
  262. Read more in the :ref:`User Guide <omp>`.
  263. Parameters
  264. ----------
  265. X : array-like of shape (n_samples, n_features)
  266. Input data. Columns are assumed to have unit norm.
  267. y : ndarray of shape (n_samples,) or (n_samples, n_targets)
  268. Input targets.
  269. n_nonzero_coefs : int, default=None
  270. Desired number of non-zero entries in the solution. If None (by
  271. default) this value is set to 10% of n_features.
  272. tol : float, default=None
  273. Maximum squared norm of the residual. If not None, overrides n_nonzero_coefs.
  274. precompute : 'auto' or bool, default=False
  275. Whether to perform precomputations. Improves performance when n_targets
  276. or n_samples is very large.
  277. copy_X : bool, default=True
  278. Whether the design matrix X must be copied by the algorithm. A false
  279. value is only helpful if X is already Fortran-ordered, otherwise a
  280. copy is made anyway.
  281. return_path : bool, default=False
  282. Whether to return every value of the nonzero coefficients along the
  283. forward path. Useful for cross-validation.
  284. return_n_iter : bool, default=False
  285. Whether or not to return the number of iterations.
  286. Returns
  287. -------
  288. coef : ndarray of shape (n_features,) or (n_features, n_targets)
  289. Coefficients of the OMP solution. If `return_path=True`, this contains
  290. the whole coefficient path. In this case its shape is
  291. (n_features, n_features) or (n_features, n_targets, n_features) and
  292. iterating over the last axis generates coefficients in increasing order
  293. of active features.
  294. n_iters : array-like or int
  295. Number of active features across every target. Returned only if
  296. `return_n_iter` is set to True.
  297. See Also
  298. --------
  299. OrthogonalMatchingPursuit : Orthogonal Matching Pursuit model.
  300. orthogonal_mp_gram : Solve OMP problems using Gram matrix and the product X.T * y.
  301. lars_path : Compute Least Angle Regression or Lasso path using LARS algorithm.
  302. sklearn.decomposition.sparse_encode : Sparse coding.
  303. Notes
  304. -----
  305. Orthogonal matching pursuit was introduced in S. Mallat, Z. Zhang,
  306. Matching pursuits with time-frequency dictionaries, IEEE Transactions on
  307. Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
  308. (https://www.di.ens.fr/~mallat/papiers/MallatPursuit93.pdf)
  309. This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
  310. M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
  311. Matching Pursuit Technical Report - CS Technion, April 2008.
  312. https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf
  313. """
  314. X = check_array(X, order="F", copy=copy_X)
  315. copy_X = False
  316. if y.ndim == 1:
  317. y = y.reshape(-1, 1)
  318. y = check_array(y)
  319. if y.shape[1] > 1: # subsequent targets will be affected
  320. copy_X = True
  321. if n_nonzero_coefs is None and tol is None:
  322. # default for n_nonzero_coefs is 0.1 * n_features
  323. # but at least one.
  324. n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)
  325. if tol is None and n_nonzero_coefs > X.shape[1]:
  326. raise ValueError(
  327. "The number of atoms cannot be more than the number of features"
  328. )
  329. if precompute == "auto":
  330. precompute = X.shape[0] > X.shape[1]
  331. if precompute:
  332. G = np.dot(X.T, X)
  333. G = np.asfortranarray(G)
  334. Xy = np.dot(X.T, y)
  335. if tol is not None:
  336. norms_squared = np.sum((y**2), axis=0)
  337. else:
  338. norms_squared = None
  339. return orthogonal_mp_gram(
  340. G,
  341. Xy,
  342. n_nonzero_coefs=n_nonzero_coefs,
  343. tol=tol,
  344. norms_squared=norms_squared,
  345. copy_Gram=copy_X,
  346. copy_Xy=False,
  347. return_path=return_path,
  348. )
  349. if return_path:
  350. coef = np.zeros((X.shape[1], y.shape[1], X.shape[1]))
  351. else:
  352. coef = np.zeros((X.shape[1], y.shape[1]))
  353. n_iters = []
  354. for k in range(y.shape[1]):
  355. out = _cholesky_omp(
  356. X, y[:, k], n_nonzero_coefs, tol, copy_X=copy_X, return_path=return_path
  357. )
  358. if return_path:
  359. _, idx, coefs, n_iter = out
  360. coef = coef[:, :, : len(idx)]
  361. for n_active, x in enumerate(coefs.T):
  362. coef[idx[: n_active + 1], k, n_active] = x[: n_active + 1]
  363. else:
  364. x, idx, n_iter = out
  365. coef[idx, k] = x
  366. n_iters.append(n_iter)
  367. if y.shape[1] == 1:
  368. n_iters = n_iters[0]
  369. if return_n_iter:
  370. return np.squeeze(coef), n_iters
  371. else:
  372. return np.squeeze(coef)
  373. def orthogonal_mp_gram(
  374. Gram,
  375. Xy,
  376. *,
  377. n_nonzero_coefs=None,
  378. tol=None,
  379. norms_squared=None,
  380. copy_Gram=True,
  381. copy_Xy=True,
  382. return_path=False,
  383. return_n_iter=False,
  384. ):
  385. """Gram Orthogonal Matching Pursuit (OMP).
  386. Solves n_targets Orthogonal Matching Pursuit problems using only
  387. the Gram matrix X.T * X and the product X.T * y.
  388. Read more in the :ref:`User Guide <omp>`.
  389. Parameters
  390. ----------
  391. Gram : ndarray of shape (n_features, n_features)
  392. Gram matrix of the input data: X.T * X.
  393. Xy : ndarray of shape (n_features,) or (n_features, n_targets)
  394. Input targets multiplied by X: X.T * y.
  395. n_nonzero_coefs : int, default=None
  396. Desired number of non-zero entries in the solution. If None (by
  397. default) this value is set to 10% of n_features.
  398. tol : float, default=None
  399. Maximum squared norm of the residual. If not `None`,
  400. overrides `n_nonzero_coefs`.
  401. norms_squared : array-like of shape (n_targets,), default=None
  402. Squared L2 norms of the lines of y. Required if tol is not None.
  403. copy_Gram : bool, default=True
  404. Whether the gram matrix must be copied by the algorithm. A false
  405. value is only helpful if it is already Fortran-ordered, otherwise a
  406. copy is made anyway.
  407. copy_Xy : bool, default=True
  408. Whether the covariance vector Xy must be copied by the algorithm.
  409. If False, it may be overwritten.
  410. return_path : bool, default=False
  411. Whether to return every value of the nonzero coefficients along the
  412. forward path. Useful for cross-validation.
  413. return_n_iter : bool, default=False
  414. Whether or not to return the number of iterations.
  415. Returns
  416. -------
  417. coef : ndarray of shape (n_features,) or (n_features, n_targets)
  418. Coefficients of the OMP solution. If `return_path=True`, this contains
  419. the whole coefficient path. In this case its shape is
  420. (n_features, n_features) or (n_features, n_targets, n_features) and
  421. iterating over the last axis yields coefficients in increasing order
  422. of active features.
  423. n_iters : array-like or int
  424. Number of active features across every target. Returned only if
  425. `return_n_iter` is set to True.
  426. See Also
  427. --------
  428. OrthogonalMatchingPursuit : Orthogonal Matching Pursuit model (OMP).
  429. orthogonal_mp : Solves n_targets Orthogonal Matching Pursuit problems.
  430. lars_path : Compute Least Angle Regression or Lasso path using
  431. LARS algorithm.
  432. sklearn.decomposition.sparse_encode : Generic sparse coding.
  433. Each column of the result is the solution to a Lasso problem.
  434. Notes
  435. -----
  436. Orthogonal matching pursuit was introduced in G. Mallat, Z. Zhang,
  437. Matching pursuits with time-frequency dictionaries, IEEE Transactions on
  438. Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
  439. (https://www.di.ens.fr/~mallat/papiers/MallatPursuit93.pdf)
  440. This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
  441. M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
  442. Matching Pursuit Technical Report - CS Technion, April 2008.
  443. https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf
  444. """
  445. Gram = check_array(Gram, order="F", copy=copy_Gram)
  446. Xy = np.asarray(Xy)
  447. if Xy.ndim > 1 and Xy.shape[1] > 1:
  448. # or subsequent target will be affected
  449. copy_Gram = True
  450. if Xy.ndim == 1:
  451. Xy = Xy[:, np.newaxis]
  452. if tol is not None:
  453. norms_squared = [norms_squared]
  454. if copy_Xy or not Xy.flags.writeable:
  455. # Make the copy once instead of many times in _gram_omp itself.
  456. Xy = Xy.copy()
  457. if n_nonzero_coefs is None and tol is None:
  458. n_nonzero_coefs = int(0.1 * len(Gram))
  459. if tol is not None and norms_squared is None:
  460. raise ValueError(
  461. "Gram OMP needs the precomputed norms in order "
  462. "to evaluate the error sum of squares."
  463. )
  464. if tol is not None and tol < 0:
  465. raise ValueError("Epsilon cannot be negative")
  466. if tol is None and n_nonzero_coefs <= 0:
  467. raise ValueError("The number of atoms must be positive")
  468. if tol is None and n_nonzero_coefs > len(Gram):
  469. raise ValueError(
  470. "The number of atoms cannot be more than the number of features"
  471. )
  472. if return_path:
  473. coef = np.zeros((len(Gram), Xy.shape[1], len(Gram)), dtype=Gram.dtype)
  474. else:
  475. coef = np.zeros((len(Gram), Xy.shape[1]), dtype=Gram.dtype)
  476. n_iters = []
  477. for k in range(Xy.shape[1]):
  478. out = _gram_omp(
  479. Gram,
  480. Xy[:, k],
  481. n_nonzero_coefs,
  482. norms_squared[k] if tol is not None else None,
  483. tol,
  484. copy_Gram=copy_Gram,
  485. copy_Xy=False,
  486. return_path=return_path,
  487. )
  488. if return_path:
  489. _, idx, coefs, n_iter = out
  490. coef = coef[:, :, : len(idx)]
  491. for n_active, x in enumerate(coefs.T):
  492. coef[idx[: n_active + 1], k, n_active] = x[: n_active + 1]
  493. else:
  494. x, idx, n_iter = out
  495. coef[idx, k] = x
  496. n_iters.append(n_iter)
  497. if Xy.shape[1] == 1:
  498. n_iters = n_iters[0]
  499. if return_n_iter:
  500. return np.squeeze(coef), n_iters
  501. else:
  502. return np.squeeze(coef)
  503. class OrthogonalMatchingPursuit(MultiOutputMixin, RegressorMixin, LinearModel):
  504. """Orthogonal Matching Pursuit model (OMP).
  505. Read more in the :ref:`User Guide <omp>`.
  506. Parameters
  507. ----------
  508. n_nonzero_coefs : int, default=None
  509. Desired number of non-zero entries in the solution. If None (by
  510. default) this value is set to 10% of n_features.
  511. tol : float, default=None
  512. Maximum squared norm of the residual. If not None, overrides n_nonzero_coefs.
  513. fit_intercept : bool, default=True
  514. Whether to calculate the intercept for this model. If set
  515. to false, no intercept will be used in calculations
  516. (i.e. data is expected to be centered).
  517. normalize : bool, default=False
  518. This parameter is ignored when ``fit_intercept`` is set to False.
  519. If True, the regressors X will be normalized before regression by
  520. subtracting the mean and dividing by the l2-norm.
  521. If you wish to standardize, please use
  522. :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
  523. on an estimator with ``normalize=False``.
  524. .. versionchanged:: 1.2
  525. default changed from True to False in 1.2.
  526. .. deprecated:: 1.2
  527. ``normalize`` was deprecated in version 1.2 and will be removed in 1.4.
  528. precompute : 'auto' or bool, default='auto'
  529. Whether to use a precomputed Gram and Xy matrix to speed up
  530. calculations. Improves performance when :term:`n_targets` or
  531. :term:`n_samples` is very large. Note that if you already have such
  532. matrices, you can pass them directly to the fit method.
  533. Attributes
  534. ----------
  535. coef_ : ndarray of shape (n_features,) or (n_targets, n_features)
  536. Parameter vector (w in the formula).
  537. intercept_ : float or ndarray of shape (n_targets,)
  538. Independent term in decision function.
  539. n_iter_ : int or array-like
  540. Number of active features across every target.
  541. n_nonzero_coefs_ : int
  542. The number of non-zero coefficients in the solution. If
  543. `n_nonzero_coefs` is None and `tol` is None this value is either set
  544. to 10% of `n_features` or 1, whichever is greater.
  545. n_features_in_ : int
  546. Number of features seen during :term:`fit`.
  547. .. versionadded:: 0.24
  548. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  549. Names of features seen during :term:`fit`. Defined only when `X`
  550. has feature names that are all strings.
  551. .. versionadded:: 1.0
  552. See Also
  553. --------
  554. orthogonal_mp : Solves n_targets Orthogonal Matching Pursuit problems.
  555. orthogonal_mp_gram : Solves n_targets Orthogonal Matching Pursuit
  556. problems using only the Gram matrix X.T * X and the product X.T * y.
  557. lars_path : Compute Least Angle Regression or Lasso path using LARS algorithm.
  558. Lars : Least Angle Regression model a.k.a. LAR.
  559. LassoLars : Lasso model fit with Least Angle Regression a.k.a. Lars.
  560. sklearn.decomposition.sparse_encode : Generic sparse coding.
  561. Each column of the result is the solution to a Lasso problem.
  562. OrthogonalMatchingPursuitCV : Cross-validated
  563. Orthogonal Matching Pursuit model (OMP).
  564. Notes
  565. -----
  566. Orthogonal matching pursuit was introduced in G. Mallat, Z. Zhang,
  567. Matching pursuits with time-frequency dictionaries, IEEE Transactions on
  568. Signal Processing, Vol. 41, No. 12. (December 1993), pp. 3397-3415.
  569. (https://www.di.ens.fr/~mallat/papiers/MallatPursuit93.pdf)
  570. This implementation is based on Rubinstein, R., Zibulevsky, M. and Elad,
  571. M., Efficient Implementation of the K-SVD Algorithm using Batch Orthogonal
  572. Matching Pursuit Technical Report - CS Technion, April 2008.
  573. https://www.cs.technion.ac.il/~ronrubin/Publications/KSVD-OMP-v2.pdf
  574. Examples
  575. --------
  576. >>> from sklearn.linear_model import OrthogonalMatchingPursuit
  577. >>> from sklearn.datasets import make_regression
  578. >>> X, y = make_regression(noise=4, random_state=0)
  579. >>> reg = OrthogonalMatchingPursuit().fit(X, y)
  580. >>> reg.score(X, y)
  581. 0.9991...
  582. >>> reg.predict(X[:1,])
  583. array([-78.3854...])
  584. """
  585. _parameter_constraints: dict = {
  586. "n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
  587. "tol": [Interval(Real, 0, None, closed="left"), None],
  588. "fit_intercept": ["boolean"],
  589. "normalize": ["boolean", Hidden(StrOptions({"deprecated"}))],
  590. "precompute": [StrOptions({"auto"}), "boolean"],
  591. }
  592. def __init__(
  593. self,
  594. *,
  595. n_nonzero_coefs=None,
  596. tol=None,
  597. fit_intercept=True,
  598. normalize="deprecated",
  599. precompute="auto",
  600. ):
  601. self.n_nonzero_coefs = n_nonzero_coefs
  602. self.tol = tol
  603. self.fit_intercept = fit_intercept
  604. self.normalize = normalize
  605. self.precompute = precompute
  606. @_fit_context(prefer_skip_nested_validation=True)
  607. def fit(self, X, y):
  608. """Fit the model using X, y as training data.
  609. Parameters
  610. ----------
  611. X : array-like of shape (n_samples, n_features)
  612. Training data.
  613. y : array-like of shape (n_samples,) or (n_samples, n_targets)
  614. Target values. Will be cast to X's dtype if necessary.
  615. Returns
  616. -------
  617. self : object
  618. Returns an instance of self.
  619. """
  620. _normalize = _deprecate_normalize(
  621. self.normalize, estimator_name=self.__class__.__name__
  622. )
  623. X, y = self._validate_data(X, y, multi_output=True, y_numeric=True)
  624. n_features = X.shape[1]
  625. X, y, X_offset, y_offset, X_scale, Gram, Xy = _pre_fit(
  626. X, y, None, self.precompute, _normalize, self.fit_intercept, copy=True
  627. )
  628. if y.ndim == 1:
  629. y = y[:, np.newaxis]
  630. if self.n_nonzero_coefs is None and self.tol is None:
  631. # default for n_nonzero_coefs is 0.1 * n_features
  632. # but at least one.
  633. self.n_nonzero_coefs_ = max(int(0.1 * n_features), 1)
  634. else:
  635. self.n_nonzero_coefs_ = self.n_nonzero_coefs
  636. if Gram is False:
  637. coef_, self.n_iter_ = orthogonal_mp(
  638. X,
  639. y,
  640. n_nonzero_coefs=self.n_nonzero_coefs_,
  641. tol=self.tol,
  642. precompute=False,
  643. copy_X=True,
  644. return_n_iter=True,
  645. )
  646. else:
  647. norms_sq = np.sum(y**2, axis=0) if self.tol is not None else None
  648. coef_, self.n_iter_ = orthogonal_mp_gram(
  649. Gram,
  650. Xy=Xy,
  651. n_nonzero_coefs=self.n_nonzero_coefs_,
  652. tol=self.tol,
  653. norms_squared=norms_sq,
  654. copy_Gram=True,
  655. copy_Xy=True,
  656. return_n_iter=True,
  657. )
  658. self.coef_ = coef_.T
  659. self._set_intercept(X_offset, y_offset, X_scale)
  660. return self
  661. def _omp_path_residues(
  662. X_train,
  663. y_train,
  664. X_test,
  665. y_test,
  666. copy=True,
  667. fit_intercept=True,
  668. normalize=False,
  669. max_iter=100,
  670. ):
  671. """Compute the residues on left-out data for a full LARS path.
  672. Parameters
  673. ----------
  674. X_train : ndarray of shape (n_samples, n_features)
  675. The data to fit the LARS on.
  676. y_train : ndarray of shape (n_samples)
  677. The target variable to fit LARS on.
  678. X_test : ndarray of shape (n_samples, n_features)
  679. The data to compute the residues on.
  680. y_test : ndarray of shape (n_samples)
  681. The target variable to compute the residues on.
  682. copy : bool, default=True
  683. Whether X_train, X_test, y_train and y_test should be copied. If
  684. False, they may be overwritten.
  685. fit_intercept : bool, default=True
  686. Whether to calculate the intercept for this model. If set
  687. to false, no intercept will be used in calculations
  688. (i.e. data is expected to be centered).
  689. normalize : bool, default=False
  690. This parameter is ignored when ``fit_intercept`` is set to False.
  691. If True, the regressors X will be normalized before regression by
  692. subtracting the mean and dividing by the l2-norm.
  693. If you wish to standardize, please use
  694. :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
  695. on an estimator with ``normalize=False``.
  696. .. versionchanged:: 1.2
  697. default changed from True to False in 1.2.
  698. .. deprecated:: 1.2
  699. ``normalize`` was deprecated in version 1.2 and will be removed in 1.4.
  700. max_iter : int, default=100
  701. Maximum numbers of iterations to perform, therefore maximum features
  702. to include. 100 by default.
  703. Returns
  704. -------
  705. residues : ndarray of shape (n_samples, max_features)
  706. Residues of the prediction on the test data.
  707. """
  708. if copy:
  709. X_train = X_train.copy()
  710. y_train = y_train.copy()
  711. X_test = X_test.copy()
  712. y_test = y_test.copy()
  713. if fit_intercept:
  714. X_mean = X_train.mean(axis=0)
  715. X_train -= X_mean
  716. X_test -= X_mean
  717. y_mean = y_train.mean(axis=0)
  718. y_train = as_float_array(y_train, copy=False)
  719. y_train -= y_mean
  720. y_test = as_float_array(y_test, copy=False)
  721. y_test -= y_mean
  722. if normalize:
  723. norms = np.sqrt(np.sum(X_train**2, axis=0))
  724. nonzeros = np.flatnonzero(norms)
  725. X_train[:, nonzeros] /= norms[nonzeros]
  726. coefs = orthogonal_mp(
  727. X_train,
  728. y_train,
  729. n_nonzero_coefs=max_iter,
  730. tol=None,
  731. precompute=False,
  732. copy_X=False,
  733. return_path=True,
  734. )
  735. if coefs.ndim == 1:
  736. coefs = coefs[:, np.newaxis]
  737. if normalize:
  738. coefs[nonzeros] /= norms[nonzeros][:, np.newaxis]
  739. return np.dot(coefs.T, X_test.T) - y_test
  740. class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
  741. """Cross-validated Orthogonal Matching Pursuit model (OMP).
  742. See glossary entry for :term:`cross-validation estimator`.
  743. Read more in the :ref:`User Guide <omp>`.
  744. Parameters
  745. ----------
  746. copy : bool, default=True
  747. Whether the design matrix X must be copied by the algorithm. A false
  748. value is only helpful if X is already Fortran-ordered, otherwise a
  749. copy is made anyway.
  750. fit_intercept : bool, default=True
  751. Whether to calculate the intercept for this model. If set
  752. to false, no intercept will be used in calculations
  753. (i.e. data is expected to be centered).
  754. normalize : bool, default=False
  755. This parameter is ignored when ``fit_intercept`` is set to False.
  756. If True, the regressors X will be normalized before regression by
  757. subtracting the mean and dividing by the l2-norm.
  758. If you wish to standardize, please use
  759. :class:`~sklearn.preprocessing.StandardScaler` before calling ``fit``
  760. on an estimator with ``normalize=False``.
  761. .. versionchanged:: 1.2
  762. default changed from True to False in 1.2.
  763. .. deprecated:: 1.2
  764. ``normalize`` was deprecated in version 1.2 and will be removed in 1.4.
  765. max_iter : int, default=None
  766. Maximum numbers of iterations to perform, therefore maximum features
  767. to include. 10% of ``n_features`` but at least 5 if available.
  768. cv : int, cross-validation generator or iterable, default=None
  769. Determines the cross-validation splitting strategy.
  770. Possible inputs for cv are:
  771. - None, to use the default 5-fold cross-validation,
  772. - integer, to specify the number of folds.
  773. - :term:`CV splitter`,
  774. - An iterable yielding (train, test) splits as arrays of indices.
  775. For integer/None inputs, :class:`~sklearn.model_selection.KFold` is used.
  776. Refer :ref:`User Guide <cross_validation>` for the various
  777. cross-validation strategies that can be used here.
  778. .. versionchanged:: 0.22
  779. ``cv`` default value if None changed from 3-fold to 5-fold.
  780. n_jobs : int, default=None
  781. Number of CPUs to use during the cross validation.
  782. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
  783. ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
  784. for more details.
  785. verbose : bool or int, default=False
  786. Sets the verbosity amount.
  787. Attributes
  788. ----------
  789. intercept_ : float or ndarray of shape (n_targets,)
  790. Independent term in decision function.
  791. coef_ : ndarray of shape (n_features,) or (n_targets, n_features)
  792. Parameter vector (w in the problem formulation).
  793. n_nonzero_coefs_ : int
  794. Estimated number of non-zero coefficients giving the best mean squared
  795. error over the cross-validation folds.
  796. n_iter_ : int or array-like
  797. Number of active features across every target for the model refit with
  798. the best hyperparameters got by cross-validating across all folds.
  799. n_features_in_ : int
  800. Number of features seen during :term:`fit`.
  801. .. versionadded:: 0.24
  802. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  803. Names of features seen during :term:`fit`. Defined only when `X`
  804. has feature names that are all strings.
  805. .. versionadded:: 1.0
  806. See Also
  807. --------
  808. orthogonal_mp : Solves n_targets Orthogonal Matching Pursuit problems.
  809. orthogonal_mp_gram : Solves n_targets Orthogonal Matching Pursuit
  810. problems using only the Gram matrix X.T * X and the product X.T * y.
  811. lars_path : Compute Least Angle Regression or Lasso path using LARS algorithm.
  812. Lars : Least Angle Regression model a.k.a. LAR.
  813. LassoLars : Lasso model fit with Least Angle Regression a.k.a. Lars.
  814. OrthogonalMatchingPursuit : Orthogonal Matching Pursuit model (OMP).
  815. LarsCV : Cross-validated Least Angle Regression model.
  816. LassoLarsCV : Cross-validated Lasso model fit with Least Angle Regression.
  817. sklearn.decomposition.sparse_encode : Generic sparse coding.
  818. Each column of the result is the solution to a Lasso problem.
  819. Notes
  820. -----
  821. In `fit`, once the optimal number of non-zero coefficients is found through
  822. cross-validation, the model is fit again using the entire training set.
  823. Examples
  824. --------
  825. >>> from sklearn.linear_model import OrthogonalMatchingPursuitCV
  826. >>> from sklearn.datasets import make_regression
  827. >>> X, y = make_regression(n_features=100, n_informative=10,
  828. ... noise=4, random_state=0)
  829. >>> reg = OrthogonalMatchingPursuitCV(cv=5).fit(X, y)
  830. >>> reg.score(X, y)
  831. 0.9991...
  832. >>> reg.n_nonzero_coefs_
  833. 10
  834. >>> reg.predict(X[:1,])
  835. array([-78.3854...])
  836. """
  837. _parameter_constraints: dict = {
  838. "copy": ["boolean"],
  839. "fit_intercept": ["boolean"],
  840. "normalize": ["boolean", Hidden(StrOptions({"deprecated"}))],
  841. "max_iter": [Interval(Integral, 0, None, closed="left"), None],
  842. "cv": ["cv_object"],
  843. "n_jobs": [Integral, None],
  844. "verbose": ["verbose"],
  845. }
  846. def __init__(
  847. self,
  848. *,
  849. copy=True,
  850. fit_intercept=True,
  851. normalize="deprecated",
  852. max_iter=None,
  853. cv=None,
  854. n_jobs=None,
  855. verbose=False,
  856. ):
  857. self.copy = copy
  858. self.fit_intercept = fit_intercept
  859. self.normalize = normalize
  860. self.max_iter = max_iter
  861. self.cv = cv
  862. self.n_jobs = n_jobs
  863. self.verbose = verbose
  864. @_fit_context(prefer_skip_nested_validation=True)
  865. def fit(self, X, y):
  866. """Fit the model using X, y as training data.
  867. Parameters
  868. ----------
  869. X : array-like of shape (n_samples, n_features)
  870. Training data.
  871. y : array-like of shape (n_samples,)
  872. Target values. Will be cast to X's dtype if necessary.
  873. Returns
  874. -------
  875. self : object
  876. Returns an instance of self.
  877. """
  878. _normalize = _deprecate_normalize(
  879. self.normalize, estimator_name=self.__class__.__name__
  880. )
  881. X, y = self._validate_data(X, y, y_numeric=True, ensure_min_features=2)
  882. X = as_float_array(X, copy=False, force_all_finite=False)
  883. cv = check_cv(self.cv, classifier=False)
  884. max_iter = (
  885. min(max(int(0.1 * X.shape[1]), 5), X.shape[1])
  886. if not self.max_iter
  887. else self.max_iter
  888. )
  889. cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
  890. delayed(_omp_path_residues)(
  891. X[train],
  892. y[train],
  893. X[test],
  894. y[test],
  895. self.copy,
  896. self.fit_intercept,
  897. _normalize,
  898. max_iter,
  899. )
  900. for train, test in cv.split(X)
  901. )
  902. min_early_stop = min(fold.shape[0] for fold in cv_paths)
  903. mse_folds = np.array(
  904. [(fold[:min_early_stop] ** 2).mean(axis=1) for fold in cv_paths]
  905. )
  906. best_n_nonzero_coefs = np.argmin(mse_folds.mean(axis=0)) + 1
  907. self.n_nonzero_coefs_ = best_n_nonzero_coefs
  908. omp = OrthogonalMatchingPursuit(
  909. n_nonzero_coefs=best_n_nonzero_coefs,
  910. fit_intercept=self.fit_intercept,
  911. normalize=_normalize,
  912. )
  913. # avoid duplicating warning for deprecated normalize
  914. with warnings.catch_warnings():
  915. warnings.filterwarnings("ignore", category=FutureWarning)
  916. omp.fit(X, y)
  917. self.coef_ = omp.coef_
  918. self.intercept_ = omp.intercept_
  919. self.n_iter_ = omp.n_iter_
  920. return self