  1. """
  2. Python implementation of the fast ICA algorithms.
  3. Reference: Tables 8.3 and 8.4 page 196 in the book:
  4. Independent Component Analysis, by Hyvarinen et al.
  5. """
  6. # Authors: Pierre Lafaye de Micheaux, Stefan van der Walt, Gael Varoquaux,
  7. # Bertrand Thirion, Alexandre Gramfort, Denis A. Engemann
  8. # License: BSD 3 clause
  9. import warnings
  10. from numbers import Integral, Real
  11. import numpy as np
  12. from scipy import linalg
  13. from ..base import (
  14. BaseEstimator,
  15. ClassNamePrefixFeaturesOutMixin,
  16. TransformerMixin,
  17. _fit_context,
  18. )
  19. from ..exceptions import ConvergenceWarning
  20. from ..utils import as_float_array, check_array, check_random_state
  21. from ..utils._param_validation import Interval, Options, StrOptions, validate_params
  22. from ..utils.validation import check_is_fitted
  23. __all__ = ["fastica", "FastICA"]
  24. def _gs_decorrelation(w, W, j):
  25. """
  26. Orthonormalize w wrt the first j rows of W.
  27. Parameters
  28. ----------
  29. w : ndarray of shape (n,)
  30. Array to be orthogonalized
  31. W : ndarray of shape (p, n)
  32. Null space definition
  33. j : int < p
  34. The no of (from the first) rows of Null space W wrt which w is
  35. orthogonalized.
  36. Notes
  37. -----
  38. Assumes that W is orthogonal
  39. w changed in place
  40. """
  41. w -= np.linalg.multi_dot([w, W[:j].T, W[:j]])
  42. return w
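
# Illustrative sketch (not part of the module; `rng` is a name made up for
# this comment): after projecting out the first j rows of an orthonormal W,
# the result is orthogonal to each of those rows.
#
#     rng = np.random.RandomState(0)
#     W, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # orthonormal rows
#     w = rng.normal(size=4)
#     _gs_decorrelation(w, W, 2)         # remove components along W[0], W[1]
#     np.allclose(W[:2] @ w, 0.0)        # -> True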


def _sym_decorrelation(W):
    """Symmetric decorrelation
    i.e. W <- (W * W.T) ^{-1/2} * W
    """
    s, u = linalg.eigh(np.dot(W, W.T))
    # Avoid sqrt of negative values because of rounding errors. Note that
    # np.sqrt(tiny) is larger than tiny and therefore this clipping also
    # prevents division by zero in the next step.
    s = np.clip(s, a_min=np.finfo(W.dtype).tiny, a_max=None)

    # u (resp. s) contains the eigenvectors (resp. eigenvalues)
    # of W * W.T
    return np.linalg.multi_dot([u * (1.0 / np.sqrt(s)), u.T, W])
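
# Illustrative sketch (not part of the module): symmetric decorrelation maps
# any full-rank W to a matrix with orthonormal rows, so W @ W.T becomes the
# identity.
#
#     rng = np.random.RandomState(0)
#     W = rng.normal(size=(3, 3))
#     Wd = _sym_decorrelation(W)
#     np.allclose(Wd @ Wd.T, np.eye(3))  # -> True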


def _ica_def(X, tol, g, fun_args, max_iter, w_init):
    """Deflationary FastICA using fun approx to neg-entropy function

    Used internally by FastICA.
    """
    n_components = w_init.shape[0]
    W = np.zeros((n_components, n_components), dtype=X.dtype)
    n_iter = []

    # j is the index of the extracted component
    for j in range(n_components):
        w = w_init[j, :].copy()
        w /= np.sqrt((w**2).sum())

        for i in range(max_iter):
            gwtx, g_wtx = g(np.dot(w.T, X), fun_args)

            w1 = (X * gwtx).mean(axis=1) - g_wtx.mean() * w

            _gs_decorrelation(w1, W, j)

            w1 /= np.sqrt((w1**2).sum())

            lim = np.abs(np.abs((w1 * w).sum()) - 1)
            w = w1
            if lim < tol:
                break

        n_iter.append(i + 1)
        W[j, :] = w

    return W, max(n_iter)


def _ica_par(X, tol, g, fun_args, max_iter, w_init):
    """Parallel FastICA.

    Used internally by FastICA (main loop).
    """
    W = _sym_decorrelation(w_init)
    del w_init
    p_ = float(X.shape[1])
    for ii in range(max_iter):
        gwtx, g_wtx = g(np.dot(W, X), fun_args)
        W1 = _sym_decorrelation(np.dot(gwtx, X.T) / p_ - g_wtx[:, np.newaxis] * W)
        del gwtx, g_wtx
        # builtin max, abs are faster than numpy counterparts.
        # np.einsum allows having the lowest memory footprint.
        # It is faster than np.diag(np.dot(W1, W.T)).
        lim = max(abs(abs(np.einsum("ij,ij->i", W1, W)) - 1))
        W = W1
        if lim < tol:
            break
    else:
        warnings.warn(
            (
                "FastICA did not converge. Consider increasing "
                "tolerance or the maximum number of iterations."
            ),
            ConvergenceWarning,
        )

    return W, ii + 1
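
# Illustrative sketch (not part of the module): the convergence criterion
# above compares each new row W1[i] with the previous W[i] through their dot
# product. np.einsum("ij,ij->i", W1, W) is the row-wise dot product, i.e. the
# diagonal of W1 @ W.T, computed without materializing the full product.
#
#     rng = np.random.RandomState(0)
#     A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 3))
#     np.allclose(np.einsum("ij,ij->i", A, B), np.diag(A @ B.T))  # -> True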


# Some standard non-linear functions.
# XXX: these should be optimized, as they can be a bottleneck.
def _logcosh(x, fun_args=None):
    alpha = fun_args.get("alpha", 1.0)  # comment it out?

    x *= alpha
    gx = np.tanh(x, x)  # apply the tanh inplace
    g_x = np.empty(x.shape[0], dtype=x.dtype)
    # XXX compute in chunks to avoid extra allocation
    for i, gx_i in enumerate(gx):  # please don't vectorize.
        g_x[i] = (alpha * (1 - gx_i**2)).mean()
    return gx, g_x


def _exp(x, fun_args):
    exp = np.exp(-(x**2) / 2)
    gx = x * exp
    g_x = (1 - x**2) * exp
    return gx, g_x.mean(axis=-1)


def _cube(x, fun_args):
    return x**3, (3 * x**2).mean(axis=-1)
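
# Illustrative sketch (not part of the module): each non-linearity returns
# g(w.T @ X) elementwise together with the mean of its derivative along the
# last axis, which is the contract a user-supplied `fun` must also honor.
# For example, `_cube` corresponds to g(u) = u**3 and g'(u) = 3 * u**2:
#
#     x = np.array([[1.0, 2.0], [3.0, 4.0]])
#     gx, g_x = _cube(x, {})
#     # gx  == x**3 elementwise
#     # g_x == (3 * x**2).mean(axis=-1) == [7.5, 37.5]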


@validate_params(
    {
        "X": ["array-like"],
        "return_X_mean": ["boolean"],
        "compute_sources": ["boolean"],
        "return_n_iter": ["boolean"],
    },
    prefer_skip_nested_validation=False,
)
def fastica(
    X,
    n_components=None,
    *,
    algorithm="parallel",
    whiten="unit-variance",
    fun="logcosh",
    fun_args=None,
    max_iter=200,
    tol=1e-04,
    w_init=None,
    whiten_solver="svd",
    random_state=None,
    return_X_mean=False,
    compute_sources=True,
    return_n_iter=False,
):
    """Perform Fast Independent Component Analysis.

    The implementation is based on [1]_.

    Read more in the :ref:`User Guide <ICA>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training vector, where `n_samples` is the number of samples and
        `n_features` is the number of features.

    n_components : int, default=None
        Number of components to use. If None is passed, all are used.

    algorithm : {'parallel', 'deflation'}, default='parallel'
        Specify which algorithm to use for FastICA.

    whiten : str or bool, default='unit-variance'
        Specify the whitening strategy to use.

        - If 'arbitrary-variance', a whitening with variance
          arbitrary is used.
        - If 'unit-variance', the whitening matrix is rescaled to ensure that
          each recovered source has unit variance.
        - If False, the data is already considered to be whitened, and no
          whitening is performed.

        .. versionchanged:: 1.3
            The default value of `whiten` changed to 'unit-variance' in 1.3.

    fun : {'logcosh', 'exp', 'cube'} or callable, default='logcosh'
        The functional form of the G function used in the
        approximation to neg-entropy. Could be either 'logcosh', 'exp',
        or 'cube'.
        You can also provide your own function. It should return a tuple
        containing the value of the function, and of its derivative, in the
        point. The derivative should be averaged along its last dimension.
        Example::

            def my_g(x):
                return x ** 3, (3 * x ** 2).mean(axis=-1)

    fun_args : dict, default=None
        Arguments to send to the functional form.
        If empty or None and if fun='logcosh', fun_args will take value
        {'alpha' : 1.0}.

    max_iter : int, default=200
        Maximum number of iterations to perform.

    tol : float, default=1e-4
        A positive scalar giving the tolerance at which the
        un-mixing matrix is considered to have converged.

    w_init : ndarray of shape (n_components, n_components), default=None
        Initial un-mixing array. If `w_init=None`, then an array of values
        drawn from a normal distribution is used.

    whiten_solver : {"eigh", "svd"}, default="svd"
        The solver to use for whitening.

        - "svd" is more stable numerically if the problem is degenerate, and
          often faster when `n_samples <= n_features`.
        - "eigh" is generally more memory efficient when
          `n_samples >= n_features`, and can be faster when
          `n_samples >= 50 * n_features`.

        .. versionadded:: 1.2

    random_state : int, RandomState instance or None, default=None
        Used to initialize ``w_init`` when not specified, with a
        normal distribution. Pass an int for reproducible results
        across multiple function calls.
        See :term:`Glossary <random_state>`.

    return_X_mean : bool, default=False
        If True, X_mean is returned too.

    compute_sources : bool, default=True
        If False, sources are not computed, but only the rotation matrix.
        This can save memory when working with big data. Defaults to True.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

    Returns
    -------
    K : ndarray of shape (n_components, n_features) or None
        If whiten is 'True', K is the pre-whitening matrix that projects data
        onto the first n_components principal components. If whiten is
        'False', K is 'None'.

    W : ndarray of shape (n_components, n_components)
        The square matrix that unmixes the data after whitening.
        The mixing matrix is the pseudo-inverse of matrix ``W K``
        if K is not None, else it is the inverse of W.

    S : ndarray of shape (n_samples, n_components) or None
        Estimated source matrix.

    X_mean : ndarray of shape (n_features,)
        The mean over features. Returned only if return_X_mean is True.

    n_iter : int
        If the algorithm is "deflation", n_iter is the
        maximum number of iterations run across all components. Else
        it is just the number of iterations taken to converge. This is
        returned only when return_n_iter is set to `True`.

    Notes
    -----
    The data matrix X is considered to be a linear combination of
    non-Gaussian (independent) components, i.e. X = AS where columns of S
    contain the independent components and A is a linear mixing
    matrix. In short, ICA attempts to 'un-mix' the data by estimating an
    un-mixing matrix W where ``S = W K X``.
    While FastICA was proposed to estimate as many sources
    as features, it is possible to estimate fewer by setting
    n_components < n_features. In this case K is not a square matrix
    and the estimated A is the pseudo-inverse of ``W K``.

    This implementation was originally made for data of shape
    [n_features, n_samples]. Now the input is transposed
    before the algorithm is applied. This makes it slightly
    faster for Fortran-ordered input.

    References
    ----------
    .. [1] A. Hyvarinen and E. Oja, "Independent Component Analysis:
        Algorithms and Applications", Neural Networks, 13(4-5), 2000,
        pp. 411-430.
    """
    est = FastICA(
        n_components=n_components,
        algorithm=algorithm,
        whiten=whiten,
        fun=fun,
        fun_args=fun_args,
        max_iter=max_iter,
        tol=tol,
        w_init=w_init,
        whiten_solver=whiten_solver,
        random_state=random_state,
    )
    est._validate_params()
    S = est._fit_transform(X, compute_sources=compute_sources)

    if est.whiten in ["unit-variance", "arbitrary-variance"]:
        K = est.whitening_
        X_mean = est.mean_
    else:
        K = None
        X_mean = None

    returned_values = [K, est._unmixing, S]
    if return_X_mean:
        returned_values.append(X_mean)
    if return_n_iter:
        returned_values.append(est.n_iter_)

    return returned_values
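
# Illustrative sketch (not part of the module; `S_true`, `A` and `X_mixed`
# are hypothetical names for this comment): unmixing two synthetic signals
# with the functional interface.
#
#     rng = np.random.RandomState(0)
#     S_true = rng.standard_t(3, size=(1000, 2))  # non-Gaussian sources
#     A = np.array([[1.0, 0.5], [0.5, 1.0]])      # mixing matrix
#     X_mixed = S_true @ A.T
#     K, W, S = fastica(X_mixed, n_components=2, random_state=0)
#     # S.shape == (1000, 2); up to sign and permutation of the columns,
#     # S approximates S_true.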


class FastICA(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
    """FastICA: a fast algorithm for Independent Component Analysis.

    The implementation is based on [1]_.

    Read more in the :ref:`User Guide <ICA>`.

    Parameters
    ----------
    n_components : int, default=None
        Number of components to use. If None is passed, all are used.

    algorithm : {'parallel', 'deflation'}, default='parallel'
        Specify which algorithm to use for FastICA.

    whiten : str or bool, default='unit-variance'
        Specify the whitening strategy to use.

        - If 'arbitrary-variance', a whitening with variance
          arbitrary is used.
        - If 'unit-variance', the whitening matrix is rescaled to ensure that
          each recovered source has unit variance.
        - If False, the data is already considered to be whitened, and no
          whitening is performed.

        .. versionchanged:: 1.3
            The default value of `whiten` changed to 'unit-variance' in 1.3.

    fun : {'logcosh', 'exp', 'cube'} or callable, default='logcosh'
        The functional form of the G function used in the
        approximation to neg-entropy. Could be either 'logcosh', 'exp',
        or 'cube'.
        You can also provide your own function. It should return a tuple
        containing the value of the function, and of its derivative, in the
        point. The derivative should be averaged along its last dimension.
        Example::

            def my_g(x):
                return x ** 3, (3 * x ** 2).mean(axis=-1)

    fun_args : dict, default=None
        Arguments to send to the functional form.
        If empty or None and if fun='logcosh', fun_args will take value
        {'alpha' : 1.0}.

    max_iter : int, default=200
        Maximum number of iterations during fit.

    tol : float, default=1e-4
        A positive scalar giving the tolerance at which the
        un-mixing matrix is considered to have converged.

    w_init : array-like of shape (n_components, n_components), default=None
        Initial un-mixing array. If `w_init=None`, then an array of values
        drawn from a normal distribution is used.

    whiten_solver : {"eigh", "svd"}, default="svd"
        The solver to use for whitening.

        - "svd" is more stable numerically if the problem is degenerate, and
          often faster when `n_samples <= n_features`.
        - "eigh" is generally more memory efficient when
          `n_samples >= n_features`, and can be faster when
          `n_samples >= 50 * n_features`.

        .. versionadded:: 1.2

    random_state : int, RandomState instance or None, default=None
        Used to initialize ``w_init`` when not specified, with a
        normal distribution. Pass an int for reproducible results
        across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    components_ : ndarray of shape (n_components, n_features)
        The linear operator to apply to the data to get the independent
        sources. This is equal to the unmixing matrix when ``whiten`` is
        False, and equal to ``np.dot(unmixing_matrix, self.whitening_)`` when
        ``whiten`` is True.

    mixing_ : ndarray of shape (n_features, n_components)
        The pseudo-inverse of ``components_``. It is the linear operator
        that maps independent sources to the data.

    mean_ : ndarray of shape (n_features,)
        The mean over features. Only set if `self.whiten` is True.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        If the algorithm is "deflation", n_iter is the
        maximum number of iterations run across all components. Else
        it is just the number of iterations taken to converge.

    whitening_ : ndarray of shape (n_components, n_features)
        Only set if whiten is 'True'. This is the pre-whitening matrix
        that projects data onto the first `n_components` principal components.

    See Also
    --------
    PCA : Principal component analysis (PCA).
    IncrementalPCA : Incremental principal components analysis (IPCA).
    KernelPCA : Kernel Principal component analysis (KPCA).
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    SparsePCA : Sparse Principal Components Analysis (SparsePCA).

    References
    ----------
    .. [1] A. Hyvarinen and E. Oja, "Independent Component Analysis:
        Algorithms and Applications", Neural Networks, 13(4-5), 2000,
        pp. 411-430.

    Examples
    --------
    >>> from sklearn.datasets import load_digits
    >>> from sklearn.decomposition import FastICA
    >>> X, _ = load_digits(return_X_y=True)
    >>> transformer = FastICA(n_components=7,
    ...                       random_state=0,
    ...                       whiten='unit-variance')
    >>> X_transformed = transformer.fit_transform(X)
    >>> X_transformed.shape
    (1797, 7)
    """

    _parameter_constraints: dict = {
        "n_components": [Interval(Integral, 1, None, closed="left"), None],
        "algorithm": [StrOptions({"parallel", "deflation"})],
        "whiten": [
            StrOptions({"arbitrary-variance", "unit-variance"}),
            Options(bool, {False}),
        ],
        "fun": [StrOptions({"logcosh", "exp", "cube"}), callable],
        "fun_args": [dict, None],
        "max_iter": [Interval(Integral, 1, None, closed="left")],
        "tol": [Interval(Real, 0.0, None, closed="left")],
        "w_init": ["array-like", None],
        "whiten_solver": [StrOptions({"eigh", "svd"})],
        "random_state": ["random_state"],
    }

    def __init__(
        self,
        n_components=None,
        *,
        algorithm="parallel",
        whiten="unit-variance",
        fun="logcosh",
        fun_args=None,
        max_iter=200,
        tol=1e-4,
        w_init=None,
        whiten_solver="svd",
        random_state=None,
    ):
        super().__init__()
        self.n_components = n_components
        self.algorithm = algorithm
        self.whiten = whiten
        self.fun = fun
        self.fun_args = fun_args
        self.max_iter = max_iter
        self.tol = tol
        self.w_init = w_init
        self.whiten_solver = whiten_solver
        self.random_state = random_state

    def _fit_transform(self, X, compute_sources=False):
        """Fit the model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        compute_sources : bool, default=False
            If False, sources are not computed, but only the rotation matrix.
            This can save memory when working with big data. Defaults to
            False.

        Returns
        -------
        S : ndarray of shape (n_samples, n_components) or None
            Sources matrix. `None` if `compute_sources` is `False`.
        """
        XT = self._validate_data(
            X, copy=self.whiten, dtype=[np.float64, np.float32], ensure_min_samples=2
        ).T
        fun_args = {} if self.fun_args is None else self.fun_args
        random_state = check_random_state(self.random_state)

        alpha = fun_args.get("alpha", 1.0)
        if not 1 <= alpha <= 2:
            raise ValueError("alpha must be in [1,2]")

        if self.fun == "logcosh":
            g = _logcosh
        elif self.fun == "exp":
            g = _exp
        elif self.fun == "cube":
            g = _cube
        elif callable(self.fun):

            def g(x, fun_args):
                return self.fun(x, **fun_args)

        n_features, n_samples = XT.shape
        n_components = self.n_components
        if not self.whiten and n_components is not None:
            n_components = None
            warnings.warn("Ignoring n_components with whiten=False.")

        if n_components is None:
            n_components = min(n_samples, n_features)
        if n_components > min(n_samples, n_features):
            n_components = min(n_samples, n_features)
            warnings.warn(
                "n_components is too large: it will be set to %s" % n_components
            )

        if self.whiten:
            # Centering the features of X
            X_mean = XT.mean(axis=-1)
            XT -= X_mean[:, np.newaxis]

            # Whitening and preprocessing by PCA
            if self.whiten_solver == "eigh":
                # Faster when num_samples >> n_features
                d, u = linalg.eigh(XT.dot(X))
                sort_indices = np.argsort(d)[::-1]
                eps = np.finfo(d.dtype).eps
                degenerate_idx = d < eps
                if np.any(degenerate_idx):
                    warnings.warn(
                        "There are some small singular values, using "
                        "whiten_solver = 'svd' might lead to more "
                        "accurate results."
                    )
                d[degenerate_idx] = eps  # For numerical issues
                np.sqrt(d, out=d)
                d, u = d[sort_indices], u[:, sort_indices]
            elif self.whiten_solver == "svd":
                u, d = linalg.svd(XT, full_matrices=False, check_finite=False)[:2]

            # Give consistent eigenvectors for both svd solvers
            u *= np.sign(u[0])

            K = (u / d).T[:n_components]  # see (6.33) p.140
            del u, d
            X1 = np.dot(K, XT)
            # see (13.6) p.267 Here X1 is white and data
            # in X has been projected onto a subspace by PCA
            X1 *= np.sqrt(n_samples)
        else:
            # X must be cast to floats to avoid typing issues with numpy
            # 2.0 and the line below
            X1 = as_float_array(XT, copy=False)  # copy has been taken care of

        w_init = self.w_init
        if w_init is None:
            w_init = np.asarray(
                random_state.normal(size=(n_components, n_components)), dtype=X1.dtype
            )
        else:
            w_init = np.asarray(w_init)
            if w_init.shape != (n_components, n_components):
                raise ValueError(
                    "w_init has invalid shape -- should be %(shape)s"
                    % {"shape": (n_components, n_components)}
                )

        kwargs = {
            "tol": self.tol,
            "g": g,
            "fun_args": fun_args,
            "max_iter": self.max_iter,
            "w_init": w_init,
        }

        if self.algorithm == "parallel":
            W, n_iter = _ica_par(X1, **kwargs)
        elif self.algorithm == "deflation":
            W, n_iter = _ica_def(X1, **kwargs)
        del X1

        self.n_iter_ = n_iter

        if compute_sources:
            if self.whiten:
                S = np.linalg.multi_dot([W, K, XT]).T
            else:
                S = np.dot(W, XT).T
        else:
            S = None

        if self.whiten:
            if self.whiten == "unit-variance":
                if not compute_sources:
                    S = np.linalg.multi_dot([W, K, XT]).T
                S_std = np.std(S, axis=0, keepdims=True)
                S /= S_std
                W /= S_std.T

            self.components_ = np.dot(W, K)
            self.mean_ = X_mean
            self.whitening_ = K
        else:
            self.components_ = W

        self.mixing_ = linalg.pinv(self.components_, check_finite=False)
        self._unmixing = W

        return S

    @_fit_context(prefer_skip_nested_validation=True)
    def fit_transform(self, X, y=None):
        """Fit the model and recover the sources from X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Estimated sources obtained by transforming the data with the
            estimated unmixing matrix.
        """
        return self._fit_transform(X, compute_sources=True)

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        """Fit the model to X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        self._fit_transform(X, compute_sources=False)
        return self

    def transform(self, X, copy=True):
        """Recover the sources from X (apply the unmixing matrix).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to transform, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        copy : bool, default=True
            If False, data passed to fit can be overwritten. Defaults to True.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Estimated sources obtained by transforming the data with the
            estimated unmixing matrix.
        """
        check_is_fitted(self)

        X = self._validate_data(
            X, copy=(copy and self.whiten), dtype=[np.float64, np.float32], reset=False
        )
        if self.whiten:
            X -= self.mean_

        return np.dot(X, self.components_.T)

    def inverse_transform(self, X, copy=True):
        """Transform the sources back to the mixed data (apply mixing matrix).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_components)
            Sources, where `n_samples` is the number of samples
            and `n_components` is the number of components.

        copy : bool, default=True
            If False, data passed to fit are overwritten. Defaults to True.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_features)
            Reconstructed data obtained with the mixing matrix.
        """
        check_is_fitted(self)

        X = check_array(X, copy=(copy and self.whiten), dtype=[np.float64, np.float32])
        X = np.dot(X, self.mixing_.T)
        if self.whiten:
            X += self.mean_

        return X
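
    # Illustrative sketch (not part of the class; `ica` and `X_demo` are
    # hypothetical names for this comment): transform and inverse_transform
    # are approximate inverses when n_components equals n_features, since
    # mixing_ is the pseudo-inverse of components_.
    #
    #     rng = np.random.RandomState(0)
    #     X_demo = rng.standard_t(3, size=(200, 3))
    #     ica = FastICA(n_components=3, whiten="unit-variance", random_state=0)
    #     S = ica.fit_transform(X_demo)
    #     X_back = ica.inverse_transform(S)
    #     np.allclose(X_demo, X_back)  # -> True (up to numerical precision)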

    @property
    def _n_features_out(self):
        """Number of transformed output features."""
        return self.components_.shape[0]

    def _more_tags(self):
        return {"preserves_dtype": [np.float32, np.float64]}