import numpy as np
import pytest
from scipy import linalg

from sklearn.cluster import KMeans
from sklearn.covariance import LedoitWolf, ShrunkCovariance, ledoit_wolf
from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis,
    QuadraticDiscriminantAnalysis,
    _cov,
)
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
from sklearn.utils._testing import (
    _convert_container,
    assert_allclose,
    assert_almost_equal,
    assert_array_almost_equal,
    assert_array_equal,
)

# Data is just 6 separable points in the plane
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype="f")
y = np.array([1, 1, 1, 2, 2, 2])
y3 = np.array([1, 1, 2, 2, 3, 3])

# Degenerate data with only one feature (still should be separable)
X1 = np.array(
    [[-2], [-1], [-1], [1], [1], [2]],
    dtype="f",
)

# Data is just 9 separable points in the plane
X6 = np.array(
    [[0, 0], [-2, -2], [-2, -1], [-1, -1], [-1, -2], [1, 3], [1, 2], [2, 1], [2, 2]]
)
y6 = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2])
y7 = np.array([1, 2, 3, 2, 3, 1, 2, 3, 1])

# Degenerate data with 1 feature (still should be separable)
X7 = np.array([[-3], [-2], [-1], [-1], [0], [1], [1], [2], [3]])

# Data that has zero variance in one dimension and needs regularization
X2 = np.array(
    [[-3, 0], [-2, 0], [-1, 0], [-1, 0], [0, 0], [1, 0], [1, 0], [2, 0], [3, 0]]
)

# One element class
y4 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 2])

# Data with fewer samples in a class than n_features
X5 = np.c_[np.arange(8), np.zeros((8, 3))]
y5 = np.array([0, 0, 0, 0, 0, 1, 1, 1])

solver_shrinkage = [
    ("svd", None),
    ("lsqr", None),
    ("eigen", None),
    ("lsqr", "auto"),
    ("lsqr", 0),
    ("lsqr", 0.43),
    ("eigen", "auto"),
    ("eigen", 0),
    ("eigen", 0.43),
]


def test_lda_predict():
    # Test LDA classification.
    # This checks that LDA implements fit and predict and returns correct
    # values for simple toy data.
    for test_case in solver_shrinkage:
        solver, shrinkage = test_case
        clf = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
        y_pred = clf.fit(X, y).predict(X)
        assert_array_equal(y_pred, y, "solver %s" % solver)

        # Assert that it works with 1D data
        y_pred1 = clf.fit(X1, y).predict(X1)
        assert_array_equal(y_pred1, y, "solver %s" % solver)

        # Test probability estimates
        y_proba_pred1 = clf.predict_proba(X1)
        assert_array_equal((y_proba_pred1[:, 1] > 0.5) + 1, y, "solver %s" % solver)
        y_log_proba_pred1 = clf.predict_log_proba(X1)
        assert_allclose(
            np.exp(y_log_proba_pred1),
            y_proba_pred1,
            rtol=1e-6,
            atol=1e-6,
            err_msg="solver %s" % solver,
        )

        # Primarily test for commit 2f34950 -- "reuse" of priors
        y_pred3 = clf.fit(X, y3).predict(X)
        # LDA shouldn't be able to separate those
        assert np.any(y_pred3 != y3), "solver %s" % solver

    clf = LinearDiscriminantAnalysis(solver="svd", shrinkage="auto")
    with pytest.raises(NotImplementedError):
        clf.fit(X, y)

    clf = LinearDiscriminantAnalysis(
        solver="lsqr", shrinkage=0.1, covariance_estimator=ShrunkCovariance()
    )
    with pytest.raises(
        ValueError,
        match=(
            "covariance_estimator and shrinkage "
            "parameters are not None. "
            "Only one of the two can be set."
        ),
    ):
        clf.fit(X, y)

    # test bad solver with covariance_estimator
    clf = LinearDiscriminantAnalysis(solver="svd", covariance_estimator=LedoitWolf())
    with pytest.raises(
        ValueError, match="covariance estimator is not supported with svd"
    ):
        clf.fit(X, y)

    # test bad covariance estimator
    clf = LinearDiscriminantAnalysis(
        solver="lsqr", covariance_estimator=KMeans(n_clusters=2, n_init="auto")
    )
    with pytest.raises(ValueError):
        clf.fit(X, y)
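

# Illustrative quick-reference sketch (not part of the original test suite):
# the solver/parameter combinations exercised above. The "svd" solver supports
# neither shrinkage nor a covariance_estimator, while "lsqr" and "eigen"
# support either one (but not both at the same time). The dict name below is
# purely illustrative.
_LDA_SOLVER_SUPPORT = {
    "svd": {"shrinkage": False, "covariance_estimator": False},
    "lsqr": {"shrinkage": True, "covariance_estimator": True},
    "eigen": {"shrinkage": True, "covariance_estimator": True},
}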


@pytest.mark.parametrize("n_classes", [2, 3])
@pytest.mark.parametrize("solver", ["svd", "lsqr", "eigen"])
def test_lda_predict_proba(solver, n_classes):
    def generate_dataset(n_samples, centers, covariances, random_state=None):
        """Generate multivariate normal data given some centers and
        covariances"""
        rng = check_random_state(random_state)
        X = np.vstack(
            [
                rng.multivariate_normal(mean, cov, size=n_samples // len(centers))
                for mean, cov in zip(centers, covariances)
            ]
        )
        y = np.hstack(
            [[clazz] * (n_samples // len(centers)) for clazz in range(len(centers))]
        )
        return X, y

    blob_centers = np.array([[0, 0], [-10, 40], [-30, 30]])[:n_classes]
    blob_stds = np.array([[[10, 10], [10, 100]]] * len(blob_centers))
    X, y = generate_dataset(
        n_samples=90000, centers=blob_centers, covariances=blob_stds, random_state=42
    )
    lda = LinearDiscriminantAnalysis(
        solver=solver, store_covariance=True, shrinkage=None
    ).fit(X, y)
    # check that the empirical means and covariances are close enough to the
    # ones used to generate the data
    assert_allclose(lda.means_, blob_centers, atol=1e-1)
    assert_allclose(lda.covariance_, blob_stds[0], atol=1)

    # implement the method to compute the probability given in The Elements
    # of Statistical Learning (cf. p.127, Sect. 4.4.5 "Logistic Regression
    # or LDA?")
    precision = linalg.inv(blob_stds[0])
    alpha_k = []
    alpha_k_0 = []
    for clazz in range(len(blob_centers) - 1):
        alpha_k.append(
            np.dot(precision, (blob_centers[clazz] - blob_centers[-1])[:, np.newaxis])
        )
        alpha_k_0.append(
            np.dot(
                -0.5 * (blob_centers[clazz] + blob_centers[-1])[np.newaxis, :],
                alpha_k[-1],
            )
        )

    sample = np.array([[-22, 22]])

    def discriminant_func(sample, coef, intercept, clazz):
        return np.exp(intercept[clazz] + np.dot(sample, coef[clazz])).item()

    prob = np.array(
        [
            float(
                discriminant_func(sample, alpha_k, alpha_k_0, clazz)
                / (
                    1
                    + sum(
                        [
                            discriminant_func(sample, alpha_k, alpha_k_0, clazz)
                            for clazz in range(n_classes - 1)
                        ]
                    )
                )
            )
            for clazz in range(n_classes - 1)
        ]
    )

    prob_ref = 1 - np.sum(prob)

    # check the consistency of the computed probabilities:
    # all probabilities should sum to one
    prob_ref_2 = float(
        1
        / (
            1
            + sum(
                [
                    discriminant_func(sample, alpha_k, alpha_k_0, clazz)
                    for clazz in range(n_classes - 1)
                ]
            )
        )
    )
    assert prob_ref == pytest.approx(prob_ref_2)

    # check that the probabilities predicted by LDA are close to the
    # theoretical probabilities
    assert_allclose(
        lda.predict_proba(sample), np.hstack([prob, prob_ref])[np.newaxis], atol=1e-2
    )
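

# Illustrative sketch (not part of the original test suite): the reference
# probabilities above follow the ESL formulation, where the LDA posterior is a
# softmax over linear discriminant scores
#     delta_k(x) = x' Sigma^{-1} mu_k - 0.5 * mu_k' Sigma^{-1} mu_k + log(pi_k).
# The helper below recomputes the posterior directly from class means, a shared
# covariance and priors; the function name and signature are illustrative only,
# not part of scikit-learn's API.
def _lda_posterior_reference(x, means, covariance, priors):
    """Return LDA class posteriors for one sample ``x`` of shape (n_features,)."""
    precision = linalg.inv(covariance)
    scores = np.array(
        [
            x @ precision @ mu - 0.5 * mu @ precision @ mu + np.log(prior)
            for mu, prior in zip(means, priors)
        ]
    )
    scores -= scores.max()  # subtract the max for numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()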


def test_lda_priors():
    # Test priors (negative priors)
    priors = np.array([0.5, -0.5])
    clf = LinearDiscriminantAnalysis(priors=priors)
    msg = "priors must be non-negative"
    with pytest.raises(ValueError, match=msg):
        clf.fit(X, y)

    # Test that priors passed as a list are correctly handled (the fit should
    # simply not fail)
    clf = LinearDiscriminantAnalysis(priors=[0.5, 0.5])
    clf.fit(X, y)

    # Test that priors always sum to 1
    priors = np.array([0.5, 0.6])
    prior_norm = np.array([0.45, 0.55])
    clf = LinearDiscriminantAnalysis(priors=priors)
    with pytest.warns(UserWarning):
        clf.fit(X, y)
    assert_array_almost_equal(clf.priors_, prior_norm, 2)


def test_lda_coefs():
    # Test if the coefficients of the solvers are approximately the same.
    n_features = 2
    n_classes = 2
    n_samples = 1000
    X, y = make_blobs(
        n_samples=n_samples, n_features=n_features, centers=n_classes, random_state=11
    )

    clf_lda_svd = LinearDiscriminantAnalysis(solver="svd")
    clf_lda_lsqr = LinearDiscriminantAnalysis(solver="lsqr")
    clf_lda_eigen = LinearDiscriminantAnalysis(solver="eigen")

    clf_lda_svd.fit(X, y)
    clf_lda_lsqr.fit(X, y)
    clf_lda_eigen.fit(X, y)

    assert_array_almost_equal(clf_lda_svd.coef_, clf_lda_lsqr.coef_, 1)
    assert_array_almost_equal(clf_lda_svd.coef_, clf_lda_eigen.coef_, 1)
    assert_array_almost_equal(clf_lda_eigen.coef_, clf_lda_lsqr.coef_, 1)


def test_lda_transform():
    # Test LDA transform.
    clf = LinearDiscriminantAnalysis(solver="svd", n_components=1)
    X_transformed = clf.fit(X, y).transform(X)
    assert X_transformed.shape[1] == 1
    clf = LinearDiscriminantAnalysis(solver="eigen", n_components=1)
    X_transformed = clf.fit(X, y).transform(X)
    assert X_transformed.shape[1] == 1

    clf = LinearDiscriminantAnalysis(solver="lsqr", n_components=1)
    clf.fit(X, y)
    msg = "transform not implemented for 'lsqr'"
    with pytest.raises(NotImplementedError, match=msg):
        clf.transform(X)


def test_lda_explained_variance_ratio():
    # Test that the normalized eigenvalues sum to 1, and that the
    # explained_variance_ratio_ computed by the eigen solver matches the one
    # computed by the svd solver.
    state = np.random.RandomState(0)
    X = state.normal(loc=0, scale=100, size=(40, 20))
    y = state.randint(0, 3, size=(40,))

    clf_lda_eigen = LinearDiscriminantAnalysis(solver="eigen")
    clf_lda_eigen.fit(X, y)
    assert_almost_equal(clf_lda_eigen.explained_variance_ratio_.sum(), 1.0, 3)
    assert clf_lda_eigen.explained_variance_ratio_.shape == (
        2,
    ), "Unexpected length for explained_variance_ratio_"

    clf_lda_svd = LinearDiscriminantAnalysis(solver="svd")
    clf_lda_svd.fit(X, y)
    assert_almost_equal(clf_lda_svd.explained_variance_ratio_.sum(), 1.0, 3)
    assert clf_lda_svd.explained_variance_ratio_.shape == (
        2,
    ), "Unexpected length for explained_variance_ratio_"

    assert_array_almost_equal(
        clf_lda_svd.explained_variance_ratio_, clf_lda_eigen.explained_variance_ratio_
    )


def test_lda_orthogonality():
    # arrange four classes with their means in a kite-shaped pattern
    # the longer distance should be transformed to the first component, and
    # the shorter distance to the second component.
    means = np.array([[0, 0, -1], [0, 2, 0], [0, -2, 0], [0, 0, 5]])

    # We construct perfectly symmetric distributions, so the LDA can estimate
    # precise means.
    scatter = np.array(
        [
            [0.1, 0, 0],
            [-0.1, 0, 0],
            [0, 0.1, 0],
            [0, -0.1, 0],
            [0, 0, 0.1],
            [0, 0, -0.1],
        ]
    )
    X = (means[:, np.newaxis, :] + scatter[np.newaxis, :, :]).reshape((-1, 3))
    y = np.repeat(np.arange(means.shape[0]), scatter.shape[0])

    # Fit LDA and transform the means
    clf = LinearDiscriminantAnalysis(solver="svd").fit(X, y)
    means_transformed = clf.transform(means)

    d1 = means_transformed[3] - means_transformed[0]
    d2 = means_transformed[2] - means_transformed[1]
    d1 /= np.sqrt(np.sum(d1**2))
    d2 /= np.sqrt(np.sum(d2**2))

    # the transformed within-class covariance should be the identity matrix
    assert_almost_equal(np.cov(clf.transform(scatter).T), np.eye(2))

    # the means of classes 0 and 3 should lie on the first component
    assert_almost_equal(np.abs(np.dot(d1[:2], [1, 0])), 1.0)

    # the means of classes 1 and 2 should lie on the second component
    assert_almost_equal(np.abs(np.dot(d2[:2], [0, 1])), 1.0)


def test_lda_scaling():
    # Test if classification works correctly with differently scaled features.
    n = 100
    rng = np.random.RandomState(1234)
    # use uniform distribution of features to make sure there is absolutely no
    # overlap between classes.
    x1 = rng.uniform(-1, 1, (n, 3)) + [-10, 0, 0]
    x2 = rng.uniform(-1, 1, (n, 3)) + [10, 0, 0]
    x = np.vstack((x1, x2)) * [1, 100, 10000]
    y = [-1] * n + [1] * n

    for solver in ("svd", "lsqr", "eigen"):
        clf = LinearDiscriminantAnalysis(solver=solver)
        # should be able to separate the data perfectly
        assert clf.fit(x, y).score(x, y) == 1.0, "using solver: %s" % solver


def test_lda_store_covariance():
    # Test for solvers 'lsqr' and 'eigen':
    # 'store_covariance' has no effect on the 'lsqr' and 'eigen' solvers, the
    # covariance_ attribute is always computed and stored.
    for solver in ("lsqr", "eigen"):
        clf = LinearDiscriminantAnalysis(solver=solver).fit(X6, y6)
        assert hasattr(clf, "covariance_")

        # Test the actual attribute:
        clf = LinearDiscriminantAnalysis(solver=solver, store_covariance=True).fit(
            X6, y6
        )
        assert hasattr(clf, "covariance_")

        assert_array_almost_equal(
            clf.covariance_, np.array([[0.422222, 0.088889], [0.088889, 0.533333]])
        )

    # For the 'svd' solver, the default is to not set the covariance_ attribute
    clf = LinearDiscriminantAnalysis(solver="svd").fit(X6, y6)
    assert not hasattr(clf, "covariance_")

    # Test the actual attribute:
    clf = LinearDiscriminantAnalysis(solver="svd", store_covariance=True).fit(X6, y6)
    assert hasattr(clf, "covariance_")

    assert_array_almost_equal(
        clf.covariance_, np.array([[0.422222, 0.088889], [0.088889, 0.533333]])
    )


@pytest.mark.parametrize("seed", range(10))
def test_lda_shrinkage(seed):
    # Test that shrunk covariance estimator and shrinkage parameter behave the
    # same
    rng = np.random.RandomState(seed)
    X = rng.rand(100, 10)
    y = rng.randint(3, size=(100,))
    c1 = LinearDiscriminantAnalysis(store_covariance=True, shrinkage=0.5, solver="lsqr")
    c2 = LinearDiscriminantAnalysis(
        store_covariance=True,
        covariance_estimator=ShrunkCovariance(shrinkage=0.5),
        solver="lsqr",
    )
    c1.fit(X, y)
    c2.fit(X, y)
    assert_allclose(c1.means_, c2.means_)
    assert_allclose(c1.covariance_, c2.covariance_)
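

# Illustrative sketch (not part of the original test suite): with the "lsqr"
# and "eigen" solvers, shrinkage=alpha applies the same convex combination as
# ShrunkCovariance, i.e. (1 - alpha) * S + alpha * (trace(S) / n_features) * I,
# where S is the empirical covariance. The helper below spells this out for a
# single data matrix; it is a sketch under that assumption, not scikit-learn's
# internal implementation.
def _manually_shrunk_covariance(X, alpha):
    """Shrink the empirical covariance of ``X`` towards a scaled identity."""
    S = np.cov(X, rowvar=False, bias=True)  # empirical (maximum-likelihood) covariance
    mu = np.trace(S) / S.shape[0]  # average variance across features
    return (1 - alpha) * S + alpha * mu * np.eye(S.shape[0])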


def test_lda_ledoitwolf():
    # When shrinkage="auto", the current implementation uses Ledoit-Wolf
    # estimation of the covariance after standardizing the data. This checks
    # that it is indeed the case.
    class StandardizedLedoitWolf:
        def fit(self, X):
            sc = StandardScaler()  # standardize features
            X_sc = sc.fit_transform(X)
            s = ledoit_wolf(X_sc)[0]
            # rescale
            s = sc.scale_[:, np.newaxis] * s * sc.scale_[np.newaxis, :]
            self.covariance_ = s

    rng = np.random.RandomState(0)
    X = rng.rand(100, 10)
    y = rng.randint(3, size=(100,))
    c1 = LinearDiscriminantAnalysis(
        store_covariance=True, shrinkage="auto", solver="lsqr"
    )
    c2 = LinearDiscriminantAnalysis(
        store_covariance=True,
        covariance_estimator=StandardizedLedoitWolf(),
        solver="lsqr",
    )
    c1.fit(X, y)
    c2.fit(X, y)
    assert_allclose(c1.means_, c2.means_)
    assert_allclose(c1.covariance_, c2.covariance_)


@pytest.mark.parametrize("n_features", [3, 5])
@pytest.mark.parametrize("n_classes", [5, 3])
def test_lda_dimension_warning(n_classes, n_features):
    rng = check_random_state(0)
    n_samples = 10
    X = rng.randn(n_samples, n_features)
    # we create n_classes labels by repeating and truncating a
    # range(n_classes) until n_samples
    y = np.tile(range(n_classes), n_samples // n_classes + 1)[:n_samples]
    max_components = min(n_features, n_classes - 1)

    for n_components in [max_components - 1, None, max_components]:
        # if n_components <= min(n_classes - 1, n_features), no warning
        lda = LinearDiscriminantAnalysis(n_components=n_components)
        lda.fit(X, y)

    for n_components in [max_components + 1, max(n_features, n_classes - 1) + 1]:
        # if n_components > min(n_classes - 1, n_features), raise error.
        # We test one unit higher than max_components, and then something
        # larger than both n_features and n_classes - 1 to ensure the test
        # works for any value of n_components
        lda = LinearDiscriminantAnalysis(n_components=n_components)
        msg = "n_components cannot be larger than "
        with pytest.raises(ValueError, match=msg):
            lda.fit(X, y)


@pytest.mark.parametrize(
    "data_type, expected_type",
    [
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64),
    ],
)
def test_lda_dtype_match(data_type, expected_type):
    for solver, shrinkage in solver_shrinkage:
        clf = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
        clf.fit(X.astype(data_type), y.astype(data_type))
        assert clf.coef_.dtype == expected_type


def test_lda_numeric_consistency_float32_float64():
    for solver, shrinkage in solver_shrinkage:
        clf_32 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
        clf_32.fit(X.astype(np.float32), y.astype(np.float32))

        clf_64 = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)
        clf_64.fit(X.astype(np.float64), y.astype(np.float64))

        # Check value consistency between types
        rtol = 1e-6
        assert_allclose(clf_32.coef_, clf_64.coef_, rtol=rtol)


def test_qda():
    # QDA classification.
    # This checks that QDA implements fit and predict and returns
    # correct values for a simple toy dataset.
    clf = QuadraticDiscriminantAnalysis()
    y_pred = clf.fit(X6, y6).predict(X6)
    assert_array_equal(y_pred, y6)

    # Assert that it works with 1D data
    y_pred1 = clf.fit(X7, y6).predict(X7)
    assert_array_equal(y_pred1, y6)

    # Test probability estimates
    y_proba_pred1 = clf.predict_proba(X7)
    assert_array_equal((y_proba_pred1[:, 1] > 0.5) + 1, y6)
    y_log_proba_pred1 = clf.predict_log_proba(X7)
    assert_array_almost_equal(np.exp(y_log_proba_pred1), y_proba_pred1, 8)

    y_pred3 = clf.fit(X6, y7).predict(X6)
    # QDA shouldn't be able to separate those
    assert np.any(y_pred3 != y7)

    # Classes should have at least 2 elements
    with pytest.raises(ValueError):
        clf.fit(X6, y4)


def test_qda_priors():
    clf = QuadraticDiscriminantAnalysis()
    y_pred = clf.fit(X6, y6).predict(X6)
    n_pos = np.sum(y_pred == 2)

    neg = 1e-10
    clf = QuadraticDiscriminantAnalysis(priors=np.array([neg, 1 - neg]))
    y_pred = clf.fit(X6, y6).predict(X6)
    n_pos2 = np.sum(y_pred == 2)

    assert n_pos2 > n_pos


@pytest.mark.parametrize("priors_type", ["list", "tuple", "array"])
def test_qda_prior_type(priors_type):
    """Check that priors accept array-like."""
    priors = [0.5, 0.5]
    clf = QuadraticDiscriminantAnalysis(
        priors=_convert_container(priors, priors_type)
    ).fit(X6, y6)
    assert isinstance(clf.priors_, np.ndarray)
    assert_array_equal(clf.priors_, priors)


def test_qda_prior_copy():
    """Check that altering `priors` without re-fitting doesn't change `priors_`."""
    priors = np.array([0.5, 0.5])
    qda = QuadraticDiscriminantAnalysis(priors=priors).fit(X, y)

    # we expect the following
    assert_array_equal(qda.priors_, qda.priors)

    # altering `priors` without re-fitting should not change `priors_`
    priors[0] = 0.2
    assert qda.priors_[0] != qda.priors[0]


def test_qda_store_covariance():
    # The default is to not set the covariance_ attribute
    clf = QuadraticDiscriminantAnalysis().fit(X6, y6)
    assert not hasattr(clf, "covariance_")

    # Test the actual attribute:
    clf = QuadraticDiscriminantAnalysis(store_covariance=True).fit(X6, y6)
    assert hasattr(clf, "covariance_")

    assert_array_almost_equal(clf.covariance_[0], np.array([[0.7, 0.45], [0.45, 0.7]]))
    assert_array_almost_equal(
        clf.covariance_[1],
        np.array([[0.33333333, -0.33333333], [-0.33333333, 0.66666667]]),
    )


def test_qda_regularization():
    # The default is reg_param=0. and will cause issues when there is a
    # constant variable.

    # Fitting on data with a constant variable triggers a UserWarning.
    collinear_msg = "Variables are collinear"
    clf = QuadraticDiscriminantAnalysis()
    with pytest.warns(UserWarning, match=collinear_msg):
        clf.fit(X2, y6)

    # XXX: a RuntimeWarning is also raised at predict time because of divisions
    # by zero when the model is fit with a constant feature and without
    # regularization: should this be considered a bug? Either make the fit-time
    # message more informative, raise an exception instead of a warning in this
    # case, or somehow change predict to avoid the division by zero.
    with pytest.warns(RuntimeWarning, match="divide by zero"):
        y_pred = clf.predict(X2)
    assert np.any(y_pred != y6)

    # Adding a little regularization fixes the division by zero at predict
    # time. But the UserWarning will persist at fit time.
    clf = QuadraticDiscriminantAnalysis(reg_param=0.01)
    with pytest.warns(UserWarning, match=collinear_msg):
        clf.fit(X2, y6)
    y_pred = clf.predict(X2)
    assert_array_equal(y_pred, y6)

    # The UserWarning should also be there for the n_samples_in_a_class <
    # n_features case.
    clf = QuadraticDiscriminantAnalysis(reg_param=0.1)
    with pytest.warns(UserWarning, match=collinear_msg):
        clf.fit(X5, y5)
    y_pred5 = clf.predict(X5)
    assert_array_equal(y_pred5, y5)
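

# Illustrative sketch (not part of the original test suite): reg_param shrinks
# each per-class covariance towards the identity, roughly
#     Sigma_k_reg = (1 - reg_param) * Sigma_k + reg_param * I,
# which is why the small positive values above remove the zero variances that
# caused the divide-by-zero warnings. The helper below mirrors that idea for a
# single class; its name is illustrative and it is a sketch, not scikit-learn's
# internal code.
def _regularized_class_covariance(X_class, reg_param):
    """Shrink the empirical covariance of one class towards the identity."""
    S = np.cov(X_class, rowvar=False, bias=True)
    return (1 - reg_param) * S + reg_param * np.eye(S.shape[0])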


def test_covariance():
    x, y = make_blobs(n_samples=100, n_features=5, centers=1, random_state=42)

    # make features correlated
    x = np.dot(x, np.arange(x.shape[1] ** 2).reshape(x.shape[1], x.shape[1]))

    c_e = _cov(x, "empirical")
    assert_almost_equal(c_e, c_e.T)

    c_s = _cov(x, "auto")
    assert_almost_equal(c_s, c_s.T)


@pytest.mark.parametrize("solver", ["svd", "lsqr", "eigen"])
def test_raises_value_error_on_same_number_of_classes_and_samples(solver):
    """
    Tests that if the number of samples equals the number
    of classes, a ValueError is raised.
    """
    X = np.array([[0.5, 0.6], [0.6, 0.5]])
    y = np.array(["a", "b"])
    clf = LinearDiscriminantAnalysis(solver=solver)
    with pytest.raises(ValueError, match="The number of samples must be more"):
        clf.fit(X, y)


def test_get_feature_names_out():
    """Check get_feature_names_out uses class name as prefix."""
    est = LinearDiscriminantAnalysis().fit(X, y)
    names_out = est.get_feature_names_out()

    class_name_lower = "LinearDiscriminantAnalysis".lower()
    expected_names_out = np.array(
        [
            f"{class_name_lower}{i}"
            for i in range(est.explained_variance_ratio_.shape[0])
        ],
        dtype=object,
    )
    assert_array_equal(names_out, expected_names_out)