  1. """
  2. Linear Discriminant Analysis and Quadratic Discriminant Analysis
  3. """
  4. # Authors: Clemens Brunner
  5. # Martin Billinger
  6. # Matthieu Perrot
  7. # Mathieu Blondel
  8. # License: BSD 3-Clause
import warnings
from numbers import Integral, Real

import numpy as np
import scipy.linalg
from scipy import linalg

from .base import (
    BaseEstimator,
    ClassifierMixin,
    ClassNamePrefixFeaturesOutMixin,
    TransformerMixin,
    _fit_context,
)
from .covariance import empirical_covariance, ledoit_wolf, shrunk_covariance
from .linear_model._base import LinearClassifierMixin
from .preprocessing import StandardScaler
from .utils._array_api import _expit, device, get_namespace, size
from .utils._param_validation import HasMethods, Interval, StrOptions
from .utils.extmath import softmax
from .utils.multiclass import check_classification_targets, unique_labels
from .utils.validation import check_is_fitted

__all__ = ["LinearDiscriminantAnalysis", "QuadraticDiscriminantAnalysis"]


def _cov(X, shrinkage=None, covariance_estimator=None):
    """Estimate covariance matrix (using optional covariance_estimator).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    shrinkage : {'empirical', 'auto'} or float, default=None
        Shrinkage parameter, possible values:
          - None or 'empirical': no shrinkage (default).
          - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
          - float between 0 and 1: fixed shrinkage parameter.

        Shrinkage parameter is ignored if `covariance_estimator`
        is not None.

    covariance_estimator : estimator, default=None
        If not None, `covariance_estimator` is used to estimate
        the covariance matrices instead of relying on the empirical
        covariance estimator (with potential shrinkage).
        The object should have a fit method and a ``covariance_`` attribute
        like the estimators in :mod:`sklearn.covariance`.
        If None, the shrinkage parameter drives the estimate.

        .. versionadded:: 0.24

    Returns
    -------
    s : ndarray of shape (n_features, n_features)
        Estimated covariance matrix.
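
    Examples
    --------
    A minimal sketch of the three shrinkage modes (the entries depend on the
    random input, so only shapes are shown here):

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> X = rng.randn(20, 3)
    >>> _cov(X).shape  # plain empirical covariance
    (3, 3)
    >>> _cov(X, shrinkage=0.5).shape  # fixed shrinkage towards the diagonal
    (3, 3)
    >>> _cov(X, shrinkage="auto").shape  # Ledoit-Wolf shrinkage
    (3, 3)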
  55. """
    if covariance_estimator is None:
        shrinkage = "empirical" if shrinkage is None else shrinkage
        if isinstance(shrinkage, str):
            if shrinkage == "auto":
                sc = StandardScaler()  # standardize features
                X = sc.fit_transform(X)
                s = ledoit_wolf(X)[0]
                # rescale
                s = sc.scale_[:, np.newaxis] * s * sc.scale_[np.newaxis, :]
            elif shrinkage == "empirical":
                s = empirical_covariance(X)
        elif isinstance(shrinkage, Real):
            s = shrunk_covariance(empirical_covariance(X), shrinkage)
    else:
        if shrinkage is not None and shrinkage != 0:
            raise ValueError(
                "covariance_estimator and shrinkage parameters "
                "are not None. Only one of the two can be set."
            )
        covariance_estimator.fit(X)
        if not hasattr(covariance_estimator, "covariance_"):
            raise ValueError(
                "%s does not have a covariance_ attribute"
                % covariance_estimator.__class__.__name__
            )
        s = covariance_estimator.covariance_
    return s


def _class_means(X, y):
    """Compute class means.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target values.

    Returns
    -------
    means : array-like of shape (n_classes, n_features)
        Class means.
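
    Examples
    --------
    A small hand-checkable example (each row of the result is the mean of the
    samples in one class):

    >>> import numpy as np
    >>> X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
    >>> y = np.array([0, 0, 1, 1])
    >>> _class_means(X, y)
    array([[2., 3.],
           [6., 7.]])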
  95. """
    xp, is_array_api_compliant = get_namespace(X)
    classes, y = xp.unique_inverse(y)
    means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype)

    if is_array_api_compliant:
        for i in range(classes.shape[0]):
            means[i, :] = xp.mean(X[y == i], axis=0)
    else:
        # TODO: Explore whether bincount + add.at is suboptimal from a
        # performance standpoint.
        cnt = np.bincount(y)
        np.add.at(means, y, X)
        means /= cnt[:, None]
    return means


def _class_cov(X, y, priors, shrinkage=None, covariance_estimator=None):
    """Compute weighted within-class covariance matrix.

    The per-class covariances are weighted by the class priors.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target values.

    priors : array-like of shape (n_classes,)
        Class priors.

    shrinkage : 'auto' or float, default=None
        Shrinkage parameter, possible values:
          - None: no shrinkage (default).
          - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
          - float between 0 and 1: fixed shrinkage parameter.

        Shrinkage parameter is ignored if `covariance_estimator` is not None.

    covariance_estimator : estimator, default=None
        If not None, `covariance_estimator` is used to estimate
        the covariance matrices instead of relying on the empirical
        covariance estimator (with potential shrinkage).
        The object should have a fit method and a ``covariance_`` attribute
        like the estimators in :mod:`sklearn.covariance`.
        If None, the shrinkage parameter drives the estimate.

        .. versionadded:: 0.24

    Returns
    -------
    cov : array-like of shape (n_features, n_features)
        Weighted within-class covariance matrix.
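
    Examples
    --------
    A minimal sketch with equal priors (only the shape is shown; the entries
    are the prior-weighted average of the per-class covariances):

    >>> import numpy as np
    >>> X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
    >>> y = np.array([0, 0, 1, 1])
    >>> _class_cov(X, y, priors=np.array([0.5, 0.5])).shape
    (2, 2)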
  138. """
    classes = np.unique(y)
    cov = np.zeros(shape=(X.shape[1], X.shape[1]))
    for idx, group in enumerate(classes):
        Xg = X[y == group, :]
        cov += priors[idx] * np.atleast_2d(_cov(Xg, shrinkage, covariance_estimator))
    return cov


class LinearDiscriminantAnalysis(
    ClassNamePrefixFeaturesOutMixin,
    LinearClassifierMixin,
    TransformerMixin,
    BaseEstimator,
):
    """Linear Discriminant Analysis.

    A classifier with a linear decision boundary, generated by fitting class
    conditional densities to the data and using Bayes' rule.

    The model fits a Gaussian density to each class, assuming that all classes
    share the same covariance matrix.

    The fitted model can also be used to reduce the dimensionality of the input
    by projecting it to the most discriminative directions, using the
    `transform` method.

    .. versionadded:: 0.17
       *LinearDiscriminantAnalysis*.

    Read more in the :ref:`User Guide <lda_qda>`.

    Parameters
    ----------
    solver : {'svd', 'lsqr', 'eigen'}, default='svd'
        Solver to use, possible values:
          - 'svd': Singular value decomposition (default).
            Does not compute the covariance matrix, therefore this solver is
            recommended for data with a large number of features.
          - 'lsqr': Least squares solution.
            Can be combined with shrinkage or custom covariance estimator.
          - 'eigen': Eigenvalue decomposition.
            Can be combined with shrinkage or custom covariance estimator.

        .. versionchanged:: 1.2
            `solver="svd"` now has experimental Array API support. See the
            :ref:`Array API User Guide <array_api>` for more details.

    shrinkage : 'auto' or float, default=None
        Shrinkage parameter, possible values:
          - None: no shrinkage (default).
          - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
          - float between 0 and 1: fixed shrinkage parameter.

        This should be left to None if `covariance_estimator` is used.
        Note that shrinkage works only with 'lsqr' and 'eigen' solvers.

    priors : array-like of shape (n_classes,), default=None
        The class prior probabilities. By default, the class proportions are
        inferred from the training data.

    n_components : int, default=None
        Number of components (<= min(n_classes - 1, n_features)) for
        dimensionality reduction. If None, will be set to
        min(n_classes - 1, n_features). This parameter only affects the
        `transform` method.

    store_covariance : bool, default=False
        If True, explicitly compute the weighted within-class covariance
        matrix when solver is 'svd'. The matrix is always computed
        and stored for the other solvers.

        .. versionadded:: 0.17

    tol : float, default=1.0e-4
        Absolute threshold for a singular value of X to be considered
        significant, used to estimate the rank of X. Dimensions whose
        singular values are non-significant are discarded. Only used if
        solver is 'svd'.

        .. versionadded:: 0.17

    covariance_estimator : covariance estimator, default=None
        If not None, `covariance_estimator` is used to estimate
        the covariance matrices instead of relying on the empirical
        covariance estimator (with potential shrinkage).
        The object should have a fit method and a ``covariance_`` attribute
        like the estimators in :mod:`sklearn.covariance`.
        If None, the shrinkage parameter drives the estimate.
        This should be left to None if `shrinkage` is used.
        Note that `covariance_estimator` works only with 'lsqr' and 'eigen'
        solvers.

        .. versionadded:: 0.24

    Attributes
    ----------
    coef_ : ndarray of shape (n_features,) or (n_classes, n_features)
        Weight vector(s).

    intercept_ : ndarray of shape (n_classes,)
        Intercept term.

    covariance_ : array-like of shape (n_features, n_features)
        Weighted within-class covariance matrix. It corresponds to
        `sum_k prior_k * C_k` where `C_k` is the covariance matrix of the
        samples in class `k`. The `C_k` are estimated using the (potentially
        shrunk) biased estimator of covariance. If solver is 'svd', only
        exists when `store_covariance` is True.

    explained_variance_ratio_ : ndarray of shape (n_components,)
        Percentage of variance explained by each of the selected components.
        If ``n_components`` is not set then all components are stored and the
        sum of explained variances is equal to 1.0. Only available when eigen
        or svd solver is used.

    means_ : array-like of shape (n_classes, n_features)
        Class-wise means.

    priors_ : array-like of shape (n_classes,)
        Class priors (sum to 1).

    scalings_ : array-like of shape (rank, n_classes - 1)
        Scaling of the features in the space spanned by the class centroids.
        Only available for 'svd' and 'eigen' solvers.

    xbar_ : array-like of shape (n_features,)
        Overall mean. Only present if solver is 'svd'.

    classes_ : array-like of shape (n_classes,)
        Unique class labels.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    QuadraticDiscriminantAnalysis : Quadratic Discriminant Analysis.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    >>> y = np.array([1, 1, 1, 2, 2, 2])
    >>> clf = LinearDiscriminantAnalysis()
    >>> clf.fit(X, y)
    LinearDiscriminantAnalysis()
    >>> print(clf.predict([[-0.8, -1]]))
    [1]
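
    Posterior probabilities are available as well; with two classes the result
    has one column per class (only the shape is shown here):

    >>> clf.predict_proba([[-0.8, -1]]).shape
    (1, 2)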
  262. """

    _parameter_constraints: dict = {
        "solver": [StrOptions({"svd", "lsqr", "eigen"})],
        "shrinkage": [StrOptions({"auto"}), Interval(Real, 0, 1, closed="both"), None],
        "n_components": [Interval(Integral, 1, None, closed="left"), None],
        "priors": ["array-like", None],
        "store_covariance": ["boolean"],
        "tol": [Interval(Real, 0, None, closed="left")],
        "covariance_estimator": [HasMethods("fit"), None],
    }

    def __init__(
        self,
        solver="svd",
        shrinkage=None,
        priors=None,
        n_components=None,
        store_covariance=False,
        tol=1e-4,
        covariance_estimator=None,
    ):
        self.solver = solver
        self.shrinkage = shrinkage
        self.priors = priors
        self.n_components = n_components
        self.store_covariance = store_covariance  # used only in svd solver
        self.tol = tol  # used only in svd solver
        self.covariance_estimator = covariance_estimator

    def _solve_lstsq(self, X, y, shrinkage, covariance_estimator):
        r"""Least squares solver.

        The least squares solver computes a straightforward solution of the
        optimal decision rule based directly on the discriminant functions. It
        can only be used for classification (with any covariance estimator),
        because estimation of eigenvectors is not performed. Therefore,
        dimensionality reduction with `transform` is not supported.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, n_classes)
            Target values.

        shrinkage : 'auto', float or None
            Shrinkage parameter, possible values:
              - None: no shrinkage.
              - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
              - float between 0 and 1: fixed shrinkage parameter.

            Shrinkage parameter is ignored if `covariance_estimator`
            is not None.

        covariance_estimator : estimator, default=None
            If not None, `covariance_estimator` is used to estimate
            the covariance matrices instead of relying on the empirical
            covariance estimator (with potential shrinkage).
            The object should have a fit method and a ``covariance_`` attribute
            like the estimators in :mod:`sklearn.covariance`.
            If None, the shrinkage parameter drives the estimate.

            .. versionadded:: 0.24

        Notes
        -----
        This solver is based on [1]_, section 2.6.2, pp. 39-41.
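
        Under the shared-covariance assumption the discriminant implemented
        here is the standard LDA log-posterior (up to a shared constant; see
        [1]_):

        .. math::

            \delta_k(x) = x^T \Sigma^{-1} \mu_k
                - \mu_k^T \Sigma^{-1} \mu_k / 2 + \log P(y=k),

        so ``coef_`` is obtained by solving :math:`\Sigma w_k = \mu_k` with
        least squares, and ``intercept_`` collects the remaining two terms.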

        References
        ----------
        .. [1] R. O. Duda, P. E. Hart, D. G. Stork. Pattern Classification
           (Second Edition). John Wiley & Sons, Inc., New York, 2001. ISBN
           0-471-05669-3.
        """
        self.means_ = _class_means(X, y)
        self.covariance_ = _class_cov(
            X, y, self.priors_, shrinkage, covariance_estimator
        )
        self.coef_ = linalg.lstsq(self.covariance_, self.means_.T)[0].T
        self.intercept_ = -0.5 * np.diag(np.dot(self.means_, self.coef_.T)) + np.log(
            self.priors_
        )

    def _solve_eigen(self, X, y, shrinkage, covariance_estimator):
        r"""Eigenvalue solver.

        The eigenvalue solver computes the optimal solution of the Rayleigh
        coefficient (basically the ratio of between-class scatter to
        within-class scatter). This solver supports both classification and
        dimensionality reduction (with any covariance estimator).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values.

        shrinkage : 'auto', float or None
            Shrinkage parameter, possible values:
              - None: no shrinkage.
              - 'auto': automatic shrinkage using the Ledoit-Wolf lemma.
              - float between 0 and 1: fixed shrinkage constant.

            Shrinkage parameter is ignored if `covariance_estimator`
            is not None.

        covariance_estimator : estimator, default=None
            If not None, `covariance_estimator` is used to estimate
            the covariance matrices instead of relying on the empirical
            covariance estimator (with potential shrinkage).
            The object should have a fit method and a ``covariance_`` attribute
            like the estimators in :mod:`sklearn.covariance`.
            If None, the shrinkage parameter drives the estimate.

            .. versionadded:: 0.24

        Notes
        -----
        This solver is based on [1]_, section 3.8.3, pp. 121-124.
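
        The discriminative directions are the generalized eigenvectors of

        .. math::

            S_b v = \lambda S_w v,

        i.e. the directions maximizing the Rayleigh quotient
        :math:`v^T S_b v / v^T S_w v` of between-class to within-class
        scatter, which is what `linalg.eigh(Sb, Sw)` computes below.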

        References
        ----------
        .. [1] R. O. Duda, P. E. Hart, D. G. Stork. Pattern Classification
           (Second Edition). John Wiley & Sons, Inc., New York, 2001. ISBN
           0-471-05669-3.
        """
        self.means_ = _class_means(X, y)
        self.covariance_ = _class_cov(
            X, y, self.priors_, shrinkage, covariance_estimator
        )

        Sw = self.covariance_  # within scatter
        St = _cov(X, shrinkage, covariance_estimator)  # total scatter
        Sb = St - Sw  # between scatter

        evals, evecs = linalg.eigh(Sb, Sw)
        self.explained_variance_ratio_ = np.sort(evals / np.sum(evals))[::-1][
            : self._max_components
        ]
        evecs = evecs[:, np.argsort(evals)[::-1]]  # sort eigenvectors

        self.scalings_ = evecs
        self.coef_ = np.dot(self.means_, evecs).dot(evecs.T)
        self.intercept_ = -0.5 * np.diag(np.dot(self.means_, self.coef_.T)) + np.log(
            self.priors_
        )

    def _solve_svd(self, X, y):
        """SVD solver.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values.
        """
        xp, is_array_api_compliant = get_namespace(X)

        if is_array_api_compliant:
            svd = xp.linalg.svd
        else:
            svd = scipy.linalg.svd

        n_samples, n_features = X.shape
        n_classes = self.classes_.shape[0]

        self.means_ = _class_means(X, y)
        if self.store_covariance:
            self.covariance_ = _class_cov(X, y, self.priors_)

        Xc = []
        for idx, group in enumerate(self.classes_):
            Xg = X[y == group]
            Xc.append(Xg - self.means_[idx, :])

        self.xbar_ = self.priors_ @ self.means_

        Xc = xp.concat(Xc, axis=0)

        # 1) within-class (univariate) scaling by the class std-devs
        std = xp.std(Xc, axis=0)
        # avoid division by zero in normalization
        std[std == 0] = 1.0
        fac = xp.asarray(1.0 / (n_samples - n_classes))

        # 2) Within variance scaling
        X = xp.sqrt(fac) * (Xc / std)
        # SVD of the centered and (within-class) scaled data
        U, S, Vt = svd(X, full_matrices=False)

        rank = xp.sum(xp.astype(S > self.tol, xp.int32))
        # Scaling of within covariance is: V' 1/S
        scalings = (Vt[:rank, :] / std).T / S[:rank]
        fac = 1.0 if n_classes == 1 else 1.0 / (n_classes - 1)

        # 3) Between variance scaling
        # Scale weighted centers
        X = (
            (xp.sqrt((n_samples * self.priors_) * fac)) * (self.means_ - self.xbar_).T
        ).T @ scalings
        # The centers live in a space of dimension at most n_classes - 1.
        # Use SVD to find the projection onto the space spanned by the
        # class centers.
        _, S, Vt = svd(X, full_matrices=False)

        if self._max_components == 0:
            self.explained_variance_ratio_ = xp.empty((0,), dtype=S.dtype)
        else:
            self.explained_variance_ratio_ = (S**2 / xp.sum(S**2))[
                : self._max_components
            ]

        rank = xp.sum(xp.astype(S > self.tol * S[0], xp.int32))
        self.scalings_ = scalings @ Vt.T[:, :rank]

        coef = (self.means_ - self.xbar_) @ self.scalings_
        self.intercept_ = -0.5 * xp.sum(coef**2, axis=1) + xp.log(self.priors_)
        self.coef_ = coef @ self.scalings_.T
        self.intercept_ -= self.xbar_ @ self.coef_.T

    @_fit_context(
        # LinearDiscriminantAnalysis.covariance_estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def fit(self, X, y):
        """Fit the Linear Discriminant Analysis model.

        .. versionchanged:: 0.19
            *store_covariance* has been moved to main constructor.

        .. versionchanged:: 0.19
            *tol* has been moved to main constructor.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,)
            Target values.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        xp, _ = get_namespace(X)

        X, y = self._validate_data(
            X, y, ensure_min_samples=2, dtype=[xp.float64, xp.float32]
        )
        self.classes_ = unique_labels(y)
        n_samples, _ = X.shape
        n_classes = self.classes_.shape[0]

        if n_samples == n_classes:
            raise ValueError(
                "The number of samples must be more than the number of classes."
            )

        if self.priors is None:  # estimate priors from sample
            _, cnts = xp.unique_counts(y)  # non-negative ints
            self.priors_ = xp.astype(cnts, X.dtype) / float(y.shape[0])
        else:
            self.priors_ = xp.asarray(self.priors, dtype=X.dtype)

        if xp.any(self.priors_ < 0):
            raise ValueError("priors must be non-negative")

        if xp.abs(xp.sum(self.priors_) - 1.0) > 1e-5:
            warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning)
            self.priors_ = self.priors_ / self.priors_.sum()

        # Maximum number of components no matter what n_components is
        # specified:
        max_components = min(n_classes - 1, X.shape[1])

        if self.n_components is None:
            self._max_components = max_components
        else:
            if self.n_components > max_components:
                raise ValueError(
                    "n_components cannot be larger than min(n_features, n_classes - 1)."
                )
            self._max_components = self.n_components

        if self.solver == "svd":
            if self.shrinkage is not None:
                raise NotImplementedError("shrinkage not supported with 'svd' solver.")
            if self.covariance_estimator is not None:
                raise ValueError(
                    "covariance estimator is not supported with svd solver. "
                    "Try another solver."
                )
            self._solve_svd(X, y)
        elif self.solver == "lsqr":
            self._solve_lstsq(
                X,
                y,
                shrinkage=self.shrinkage,
                covariance_estimator=self.covariance_estimator,
            )
        elif self.solver == "eigen":
            self._solve_eigen(
                X,
                y,
                shrinkage=self.shrinkage,
                covariance_estimator=self.covariance_estimator,
            )
        if size(self.classes_) == 2:  # treat binary case as a special case
            coef_ = xp.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype)
            self.coef_ = xp.reshape(coef_, (1, -1))
            intercept_ = xp.asarray(
                self.intercept_[1] - self.intercept_[0], dtype=X.dtype
            )
            self.intercept_ = xp.reshape(intercept_, (1,))
        self._n_features_out = self._max_components
        return self

    def transform(self, X):
        """Project data to maximize class separation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components) or \
            (n_samples, min(rank, n_components))
            Transformed data. In the case of the 'svd' solver, the shape
            is (n_samples, min(rank, n_components)).
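
        Examples
        --------
        A minimal sketch on toy data (two classes in two features give at
        most one discriminant component):

        >>> import numpy as np
        >>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
        >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
        >>> y = np.array([1, 1, 1, 2, 2, 2])
        >>> LinearDiscriminantAnalysis().fit(X, y).transform(X).shape
        (6, 1)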
  545. """
        if self.solver == "lsqr":
            raise NotImplementedError(
                "transform not implemented for 'lsqr' solver (use 'svd' or 'eigen')."
            )
        check_is_fitted(self)
        xp, _ = get_namespace(X)

        X = self._validate_data(X, reset=False)

        if self.solver == "svd":
            X_new = (X - self.xbar_) @ self.scalings_
        elif self.solver == "eigen":
            X_new = X @ self.scalings_

        return X_new[:, : self._max_components]

    def predict_proba(self, X):
        """Estimate probability.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        C : ndarray of shape (n_samples, n_classes)
            Estimated probabilities.
        """
        check_is_fitted(self)
        xp, is_array_api_compliant = get_namespace(X)
        decision = self.decision_function(X)
        if size(self.classes_) == 2:
            proba = _expit(decision)
            return xp.stack([1 - proba, proba], axis=1)
        else:
            return softmax(decision)

    def predict_log_proba(self, X):
        """Estimate log probability.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        C : ndarray of shape (n_samples, n_classes)
            Estimated log probabilities.
        """
        xp, _ = get_namespace(X)
        prediction = self.predict_proba(X)

        info = xp.finfo(prediction.dtype)
        if hasattr(info, "smallest_normal"):
            smallest_normal = info.smallest_normal
        else:
            # smallest_normal was introduced in NumPy 1.22
            smallest_normal = info.tiny

        prediction[prediction == 0.0] += smallest_normal
        return xp.log(prediction)

    def decision_function(self, X):
        """Apply decision function to an array of samples.

        The decision function is equal (up to a constant factor) to the
        log-posterior of the model, i.e. `log p(y = k | x)`. In a binary
        classification setting this instead corresponds to the difference
        `log p(y = 1 | x) - log p(y = 0 | x)`. See :ref:`lda_qda_math`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Array of samples (test vectors).

        Returns
        -------
        C : ndarray of shape (n_samples,) or (n_samples, n_classes)
            Decision function values related to each class, per sample.
            In the two-class case, the shape is (n_samples,), giving the
            log likelihood ratio of the positive class.
        """
        # Only override for the doc
        return super().decision_function(X)

    def _more_tags(self):
        return {"array_api_support": True}


class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
    """Quadratic Discriminant Analysis.

    A classifier with a quadratic decision boundary, generated
    by fitting class conditional densities to the data
    and using Bayes' rule.

    The model fits a Gaussian density to each class.

    .. versionadded:: 0.17
       *QuadraticDiscriminantAnalysis*

    Read more in the :ref:`User Guide <lda_qda>`.

    Parameters
    ----------
    priors : array-like of shape (n_classes,), default=None
        Class priors. By default, the class proportions are inferred from the
        training data.

    reg_param : float, default=0.0
        Regularizes the per-class covariance estimates by transforming S2 as
        ``S2 = (1 - reg_param) * S2 + reg_param * np.eye(n_features)``,
        where S2 corresponds to the `scaling_` attribute of a given class.

    store_covariance : bool, default=False
        If True, the class covariance matrices are explicitly computed and
        stored in the `self.covariance_` attribute.

        .. versionadded:: 0.17

    tol : float, default=1.0e-4
        Absolute threshold for a singular value to be considered significant,
        used to estimate the rank of `Xk` where `Xk` is the centered matrix
        of samples in class k. This parameter does not affect the
        predictions. It only controls a warning that is raised when features
        are considered to be collinear.

        .. versionadded:: 0.17

    Attributes
    ----------
    covariance_ : list of len n_classes of ndarray \
            of shape (n_features, n_features)
        For each class, gives the covariance matrix estimated using the
        samples of that class. The estimations are unbiased. Only present if
        `store_covariance` is True.

    means_ : array-like of shape (n_classes, n_features)
        Class-wise means.

    priors_ : array-like of shape (n_classes,)
        Class priors (sum to 1).

    rotations_ : list of len n_classes of ndarray of shape (n_features, n_k)
        For each class k an array of shape (n_features, n_k), where
        ``n_k = min(n_features, number of elements in class k)``.
        It is the rotation of the Gaussian distribution, i.e. its
        principal axes. It corresponds to `V`, the matrix of eigenvectors
        coming from the SVD of `Xk = U S Vt` where `Xk` is the centered
        matrix of samples from class k.

    scalings_ : list of len n_classes of ndarray of shape (n_k,)
        For each class, contains the scaling of
        the Gaussian distributions along its principal axes, i.e. the
        variance in the rotated coordinate system. It corresponds to `S^2 /
        (n_samples - 1)`, where `S` is the diagonal matrix of singular values
        from the SVD of `Xk`, where `Xk` is the centered matrix of samples
        from class k.

    classes_ : ndarray of shape (n_classes,)
        Unique class labels.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    LinearDiscriminantAnalysis : Linear Discriminant Analysis.

    Examples
    --------
    >>> from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
    >>> import numpy as np
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    >>> y = np.array([1, 1, 1, 2, 2, 2])
    >>> clf = QuadraticDiscriminantAnalysis()
    >>> clf.fit(X, y)
    QuadraticDiscriminantAnalysis()
    >>> print(clf.predict([[-0.8, -1]]))
    [1]
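
    Class posterior probabilities follow from the per-class densities (only
    the shape is shown here):

    >>> clf.predict_proba([[-0.8, -1]]).shape
    (1, 2)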
  696. """

    _parameter_constraints: dict = {
        "priors": ["array-like", None],
        "reg_param": [Interval(Real, 0, 1, closed="both")],
        "store_covariance": ["boolean"],
        "tol": [Interval(Real, 0, None, closed="left")],
    }

    def __init__(
        self, *, priors=None, reg_param=0.0, store_covariance=False, tol=1.0e-4
    ):
        self.priors = priors
        self.reg_param = reg_param
        self.store_covariance = store_covariance
        self.tol = tol

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        """Fit the model according to the given training data and parameters.

        .. versionchanged:: 0.19
            ``store_covariances`` has been moved to main constructor as
            ``store_covariance``.

        .. versionchanged:: 0.19
            ``tol`` has been moved to main constructor.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Target values (integers).

        Returns
        -------
        self : object
            Fitted estimator.
        """
        X, y = self._validate_data(X, y)
        check_classification_targets(y)
        self.classes_, y = np.unique(y, return_inverse=True)
        n_samples, n_features = X.shape
        n_classes = len(self.classes_)
        if n_classes < 2:
            raise ValueError(
                "The number of classes has to be greater than one; got %d class"
                % (n_classes)
            )
        if self.priors is None:
            self.priors_ = np.bincount(y) / float(n_samples)
        else:
            self.priors_ = np.array(self.priors)

        cov = None
        store_covariance = self.store_covariance
        if store_covariance:
            cov = []
        means = []
        scalings = []
        rotations = []
        for ind in range(n_classes):
            Xg = X[y == ind, :]
            meang = Xg.mean(0)
            means.append(meang)
            if len(Xg) == 1:
                raise ValueError(
                    "y has only 1 sample in class %s, covariance is ill defined."
                    % str(self.classes_[ind])
                )
            Xgc = Xg - meang
            # Xgc = U * S * V.T
            _, S, Vt = np.linalg.svd(Xgc, full_matrices=False)
            rank = np.sum(S > self.tol)
            if rank < n_features:
                warnings.warn("Variables are collinear")
            S2 = (S**2) / (len(Xg) - 1)
            # regularize: shrink the spectrum towards the identity
            S2 = ((1 - self.reg_param) * S2) + self.reg_param
            if store_covariance:
                # cov = V * (S^2 / (n-1)) * V.T
                cov.append(np.dot(S2 * Vt.T, Vt))
            scalings.append(S2)
            rotations.append(Vt.T)
        if store_covariance:
            self.covariance_ = cov
        self.means_ = np.asarray(means)
        self.scalings_ = scalings
        self.rotations_ = rotations
        return self

    def _decision_function(self, X):
        # return log posterior, see eq (4.12) p. 110 of the ESL.
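        # Sketch of the computation: for class k with centered samples
        # Xk = U S Vt, the (regularized) covariance is V diag(S2) V.T, so the
        # Mahalanobis term can be evaluated in the rotated coordinates:
        #   (x - mu_k).T Sigma_k^-1 (x - mu_k) = ||(x - mu_k) @ (V / sqrt(S2))||^2
        # and log|Sigma_k| = sum(log(S2)). The log posterior (up to a shared
        # constant) is then -0.5 * (mahalanobis + log|Sigma_k|) + log(prior_k).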
        check_is_fitted(self)

        X = self._validate_data(X, reset=False)
        norm2 = []
        for i in range(len(self.classes_)):
            R = self.rotations_[i]
            S = self.scalings_[i]
            Xm = X - self.means_[i]
            X2 = np.dot(Xm, R * (S ** (-0.5)))
            norm2.append(np.sum(X2**2, axis=1))
        norm2 = np.array(norm2).T  # shape = [len(X), n_classes]
        u = np.asarray([np.sum(np.log(s)) for s in self.scalings_])
        return -0.5 * (norm2 + u) + np.log(self.priors_)

    def decision_function(self, X):
        """Apply decision function to an array of samples.

        The decision function is equal (up to a constant factor) to the
        log-posterior of the model, i.e. `log p(y = k | x)`. In a binary
        classification setting this instead corresponds to the difference
        `log p(y = 1 | x) - log p(y = 0 | x)`. See :ref:`lda_qda_math`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Array of samples (test vectors).

        Returns
        -------
        C : ndarray of shape (n_samples,) or (n_samples, n_classes)
            Decision function values related to each class, per sample.
            In the two-class case, the shape is (n_samples,), giving the
            log likelihood ratio of the positive class.
        """
        dec_func = self._decision_function(X)
        # handle special case of two classes
        if len(self.classes_) == 2:
            return dec_func[:, 1] - dec_func[:, 0]
        return dec_func

    def predict(self, X):
        """Perform classification on an array of test vectors X.

        The predicted class C for each sample in X is returned.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Vector to be scored, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        C : ndarray of shape (n_samples,)
            Estimated class labels.
        """
        d = self._decision_function(X)
        y_pred = self.classes_.take(d.argmax(1))
        return y_pred

    def predict_proba(self, X):
        """Return posterior probabilities of classification.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Array of samples/test vectors.

        Returns
        -------
        C : ndarray of shape (n_samples, n_classes)
            Posterior probabilities of classification per class.
        """
        values = self._decision_function(X)
        # compute the likelihood of the underlying gaussian models
        # up to a multiplicative constant.
        likelihood = np.exp(values - values.max(axis=1)[:, np.newaxis])
        # compute posterior probabilities
        return likelihood / likelihood.sum(axis=1)[:, np.newaxis]

    def predict_log_proba(self, X):
        """Return log of posterior probabilities of classification.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Array of samples/test vectors.

        Returns
        -------
        C : ndarray of shape (n_samples, n_classes)
            Posterior log-probabilities of classification per class.
        """
        # XXX : can do better to avoid precision overflows
        probas_ = self.predict_proba(X)
        return np.log(probas_)