# Author: Wei Xue <xuewei4d@gmail.com>
#         Thierry Guillemot <thierry.guillemot.work@gmail.com>
# License: BSD 3 clause

import copy

import numpy as np
import pytest
from scipy.special import gammaln

from sklearn.exceptions import ConvergenceWarning, NotFittedError
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.mixture import BayesianGaussianMixture
from sklearn.mixture._bayesian_mixture import _log_dirichlet_norm, _log_wishart_norm
from sklearn.mixture.tests.test_gaussian_mixture import RandomData
from sklearn.utils._testing import (
    assert_almost_equal,
    assert_array_equal,
    ignore_warnings,
)

COVARIANCE_TYPE = ["full", "tied", "diag", "spherical"]
PRIOR_TYPE = ["dirichlet_process", "dirichlet_distribution"]


def test_log_dirichlet_norm():
    rng = np.random.RandomState(0)

    weight_concentration = rng.rand(2)
    expected_norm = gammaln(np.sum(weight_concentration)) - np.sum(
        gammaln(weight_concentration)
    )
    predicted_norm = _log_dirichlet_norm(weight_concentration)

    assert_almost_equal(expected_norm, predicted_norm)
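

# For reference, the quantity checked above is the log normalization constant
# of a Dirichlet distribution with concentration vector alpha:
#
#     log C(alpha) = lnGamma(sum_k alpha_k) - sum_k lnGamma(alpha_k)
#
# which is exactly what the expected_norm expression computes via gammaln.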


def test_log_wishart_norm():
    rng = np.random.RandomState(0)

    n_components, n_features = 5, 2
    degrees_of_freedom = np.abs(rng.rand(n_components)) + 1.0
    log_det_precisions_chol = n_features * np.log(range(2, 2 + n_components))

    expected_norm = np.empty(5)
    for k, (degrees_of_freedom_k, log_det_k) in enumerate(
        zip(degrees_of_freedom, log_det_precisions_chol)
    ):
        expected_norm[k] = -(
            degrees_of_freedom_k * (log_det_k + 0.5 * n_features * np.log(2.0))
            + np.sum(
                gammaln(
                    0.5
                    * (degrees_of_freedom_k - np.arange(0, n_features)[:, np.newaxis])
                ),
                0,
            )
        ).item()
    predicted_norm = _log_wishart_norm(
        degrees_of_freedom, log_det_precisions_chol, n_features
    )

    assert_almost_equal(expected_norm, predicted_norm)
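

# For reference, the expected value above is the log normalization term of the
# Wishart prior, written in terms of the log-determinant of the precision
# Cholesky factor (log|chol| = 0.5 * log|precision|):
#
#     log B = -(nu * log|chol| + nu * d / 2 * log(2)
#               + sum_{i=0}^{d-1} lnGamma((nu - i) / 2))
#
# The constant d * (d - 1) / 4 * log(pi) term of the multivariate gamma
# function is dropped by _log_wishart_norm, so it is omitted here as well.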


def test_bayesian_mixture_weights_prior_initialisation():
    rng = np.random.RandomState(0)
    n_samples, n_components, n_features = 10, 5, 2
    X = rng.rand(n_samples, n_features)

    # Check correct init for a given value of weight_concentration_prior
    weight_concentration_prior = rng.rand()
    bgmm = BayesianGaussianMixture(
        weight_concentration_prior=weight_concentration_prior, random_state=rng
    ).fit(X)
    assert_almost_equal(weight_concentration_prior, bgmm.weight_concentration_prior_)

    # Check correct init for the default value of weight_concentration_prior
    bgmm = BayesianGaussianMixture(n_components=n_components, random_state=rng).fit(X)
    assert_almost_equal(1.0 / n_components, bgmm.weight_concentration_prior_)


def test_bayesian_mixture_mean_prior_initialisation():
    rng = np.random.RandomState(0)
    n_samples, n_components, n_features = 10, 3, 2
    X = rng.rand(n_samples, n_features)

    # Check correct init for a given value of mean_precision_prior
    mean_precision_prior = rng.rand()
    bgmm = BayesianGaussianMixture(
        mean_precision_prior=mean_precision_prior, random_state=rng
    ).fit(X)
    assert_almost_equal(mean_precision_prior, bgmm.mean_precision_prior_)

    # Check correct init for the default value of mean_precision_prior
    bgmm = BayesianGaussianMixture(random_state=rng).fit(X)
    assert_almost_equal(1.0, bgmm.mean_precision_prior_)

    # Check correct init for a given value of mean_prior
    mean_prior = rng.rand(n_features)
    bgmm = BayesianGaussianMixture(
        n_components=n_components, mean_prior=mean_prior, random_state=rng
    ).fit(X)
    assert_almost_equal(mean_prior, bgmm.mean_prior_)

    # Check correct init for the default value of mean_prior
    bgmm = BayesianGaussianMixture(n_components=n_components, random_state=rng).fit(X)
    assert_almost_equal(X.mean(axis=0), bgmm.mean_prior_)


def test_bayesian_mixture_precisions_prior_initialisation():
    rng = np.random.RandomState(0)
    n_samples, n_features = 10, 2
    X = rng.rand(n_samples, n_features)

    # Check raise message for a bad value of degrees_of_freedom_prior
    bad_degrees_of_freedom_prior_ = n_features - 1.0
    bgmm = BayesianGaussianMixture(
        degrees_of_freedom_prior=bad_degrees_of_freedom_prior_, random_state=rng
    )
    msg = (
        "The parameter 'degrees_of_freedom_prior' should be greater than"
        f" {n_features - 1}, but got {bad_degrees_of_freedom_prior_:.3f}."
    )
    with pytest.raises(ValueError, match=msg):
        bgmm.fit(X)

    # Check correct init for a given value of degrees_of_freedom_prior
    degrees_of_freedom_prior = rng.rand() + n_features - 1.0
    bgmm = BayesianGaussianMixture(
        degrees_of_freedom_prior=degrees_of_freedom_prior, random_state=rng
    ).fit(X)
    assert_almost_equal(degrees_of_freedom_prior, bgmm.degrees_of_freedom_prior_)

    # Check correct init for the default value of degrees_of_freedom_prior
    degrees_of_freedom_prior_default = n_features
    bgmm = BayesianGaussianMixture(
        degrees_of_freedom_prior=degrees_of_freedom_prior_default, random_state=rng
    ).fit(X)
    assert_almost_equal(
        degrees_of_freedom_prior_default, bgmm.degrees_of_freedom_prior_
    )

    # Check correct init for a given value of covariance_prior
    covariance_prior = {
        "full": np.cov(X.T, bias=1) + 10,
        "tied": np.cov(X.T, bias=1) + 5,
        "diag": np.diag(np.atleast_2d(np.cov(X.T, bias=1))) + 3,
        "spherical": rng.rand(),
    }

    bgmm = BayesianGaussianMixture(random_state=rng)
    for cov_type in ["full", "tied", "diag", "spherical"]:
        bgmm.covariance_type = cov_type
        bgmm.covariance_prior = covariance_prior[cov_type]
        bgmm.fit(X)
        assert_almost_equal(covariance_prior[cov_type], bgmm.covariance_prior_)

    # Check correct init for the default value of covariance_prior
    covariance_prior_default = {
        "full": np.atleast_2d(np.cov(X.T)),
        "tied": np.atleast_2d(np.cov(X.T)),
        "diag": np.var(X, axis=0, ddof=1),
        "spherical": np.var(X, axis=0, ddof=1).mean(),
    }

    bgmm = BayesianGaussianMixture(random_state=0)
    for cov_type in ["full", "tied", "diag", "spherical"]:
        bgmm.covariance_type = cov_type
        bgmm.fit(X)
        assert_almost_equal(covariance_prior_default[cov_type], bgmm.covariance_prior_)


def test_bayesian_mixture_check_is_fitted():
    rng = np.random.RandomState(0)
    n_samples, n_features = 10, 2

    # Check raise message
    bgmm = BayesianGaussianMixture(random_state=rng)
    X = rng.rand(n_samples, n_features)

    msg = "This BayesianGaussianMixture instance is not fitted yet."
    with pytest.raises(ValueError, match=msg):
        bgmm.score(X)
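

# Note: the estimator actually raises NotFittedError here; catching it with
# pytest.raises(ValueError, ...) works because sklearn's NotFittedError
# subclasses both ValueError and AttributeError.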


def test_bayesian_mixture_weights():
    rng = np.random.RandomState(0)
    n_samples, n_features = 10, 2
    X = rng.rand(n_samples, n_features)

    # Case Dirichlet distribution for the weight concentration prior type
    bgmm = BayesianGaussianMixture(
        weight_concentration_prior_type="dirichlet_distribution",
        n_components=3,
        random_state=rng,
    ).fit(X)

    expected_weights = bgmm.weight_concentration_ / np.sum(bgmm.weight_concentration_)
    assert_almost_equal(expected_weights, bgmm.weights_)
    assert_almost_equal(np.sum(bgmm.weights_), 1.0)

    # Case Dirichlet process for the weight concentration prior type
    dpgmm = BayesianGaussianMixture(
        weight_concentration_prior_type="dirichlet_process",
        n_components=3,
        random_state=rng,
    ).fit(X)

    weight_dirichlet_sum = (
        dpgmm.weight_concentration_[0] + dpgmm.weight_concentration_[1]
    )
    tmp = dpgmm.weight_concentration_[1] / weight_dirichlet_sum
    expected_weights = (
        dpgmm.weight_concentration_[0]
        / weight_dirichlet_sum
        * np.hstack((1, np.cumprod(tmp[:-1])))
    )
    expected_weights /= np.sum(expected_weights)
    assert_almost_equal(expected_weights, dpgmm.weights_)
    assert_almost_equal(np.sum(dpgmm.weights_), 1.0)
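

# For reference, the dirichlet_process expectation above follows the truncated
# stick-breaking construction: weight_concentration_ stores the Beta posterior
# parameters (a, b), and the expected weight of component k is
#
#     E[w_k] = a_k / (a_k + b_k) * prod_{j<k} b_j / (a_j + b_j),
#
# i.e. the expected length of the k-th stick piece, renormalized so that the
# truncated weights sum to one.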


@ignore_warnings(category=ConvergenceWarning)
def test_monotonic_likelihood():
    # We check that each step of variational inference without regularization
    # monotonically improves the lower bound on the training set.
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng, scale=20)
    n_components = rand_data.n_components

    for prior_type in PRIOR_TYPE:
        for covar_type in COVARIANCE_TYPE:
            X = rand_data.X[covar_type]
            bgmm = BayesianGaussianMixture(
                weight_concentration_prior_type=prior_type,
                n_components=2 * n_components,
                covariance_type=covar_type,
                warm_start=True,
                max_iter=1,
                random_state=rng,
                tol=1e-3,
            )
            current_lower_bound = -np.inf
            # Do one training iteration at a time so we can make sure that the
            # training log likelihood increases after each iteration.
            for _ in range(600):
                prev_lower_bound = current_lower_bound
                current_lower_bound = bgmm.fit(X).lower_bound_
                assert current_lower_bound >= prev_lower_bound

                if bgmm.converged_:
                    break
            assert bgmm.converged_
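

# A minimal sketch (not one of the tests) of the warm-start pattern used
# above: with warm_start=True and max_iter=1, every call to fit resumes from
# the previous parameters, exposing one variational update at a time. The
# synthetic X here is an illustrative stand-in, not a fixture of this module.
def _manual_iteration_sketch():
    rng = np.random.RandomState(0)
    X = rng.randn(100, 2)
    bgmm = BayesianGaussianMixture(
        n_components=3, warm_start=True, max_iter=1, random_state=0
    )
    lower_bounds = []
    # Each single-iteration fit emits a ConvergenceWarning until converged_.
    with ignore_warnings(category=ConvergenceWarning):
        for _ in range(200):
            lower_bounds.append(bgmm.fit(X).lower_bound_)
            if bgmm.converged_:
                break
    return lower_bounds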


def test_compare_covar_type():
    # We can compare the 'full' precision with the other cov_type if we apply
    # 1 iter of the M-step (done during _initialize_parameters).
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng, scale=7)
    X = rand_data.X["full"]
    n_components = rand_data.n_components

    for prior_type in PRIOR_TYPE:
        # Computation of the full_covariance
        bgmm = BayesianGaussianMixture(
            weight_concentration_prior_type=prior_type,
            n_components=2 * n_components,
            covariance_type="full",
            max_iter=1,
            random_state=0,
            tol=1e-7,
        )
        bgmm._check_parameters(X)
        bgmm._initialize_parameters(X, np.random.RandomState(0))
        full_covariances = (
            bgmm.covariances_ * bgmm.degrees_of_freedom_[:, np.newaxis, np.newaxis]
        )

        # Check tied_covariance = mean(full_covariances, 0)
        bgmm = BayesianGaussianMixture(
            weight_concentration_prior_type=prior_type,
            n_components=2 * n_components,
            covariance_type="tied",
            max_iter=1,
            random_state=0,
            tol=1e-7,
        )
        bgmm._check_parameters(X)
        bgmm._initialize_parameters(X, np.random.RandomState(0))

        tied_covariance = bgmm.covariances_ * bgmm.degrees_of_freedom_
        assert_almost_equal(tied_covariance, np.mean(full_covariances, 0))

        # Check diag_covariance = diag(full_covariances)
        bgmm = BayesianGaussianMixture(
            weight_concentration_prior_type=prior_type,
            n_components=2 * n_components,
            covariance_type="diag",
            max_iter=1,
            random_state=0,
            tol=1e-7,
        )
        bgmm._check_parameters(X)
        bgmm._initialize_parameters(X, np.random.RandomState(0))

        diag_covariances = bgmm.covariances_ * bgmm.degrees_of_freedom_[:, np.newaxis]
        assert_almost_equal(
            diag_covariances, np.array([np.diag(cov) for cov in full_covariances])
        )

        # Check spherical_covariance = np.mean(diag_covariances, 1)
        bgmm = BayesianGaussianMixture(
            weight_concentration_prior_type=prior_type,
            n_components=2 * n_components,
            covariance_type="spherical",
            max_iter=1,
            random_state=0,
            tol=1e-7,
        )
        bgmm._check_parameters(X)
        bgmm._initialize_parameters(X, np.random.RandomState(0))

        spherical_covariances = bgmm.covariances_ * bgmm.degrees_of_freedom_
        assert_almost_equal(spherical_covariances, np.mean(diag_covariances, 1))
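

# Note on the covariances_ * degrees_of_freedom_ products above: the
# implementation stores covariances_ already normalized by the degrees of
# freedom of the Wishart posterior, so multiplying the normalization back in
# recovers quantities that are directly comparable across covariance types
# after the same single initialization step.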


@ignore_warnings(category=ConvergenceWarning)
def test_check_covariance_precision():
    # We check that the dot product of the covariance and the precision
    # matrices is the identity matrix.
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng, scale=7)
    n_components, n_features = 2 * rand_data.n_components, 2

    # Computation of the full_covariance
    bgmm = BayesianGaussianMixture(
        n_components=n_components, max_iter=100, random_state=rng, tol=1e-3, reg_covar=0
    )
    for covar_type in COVARIANCE_TYPE:
        bgmm.covariance_type = covar_type
        bgmm.fit(rand_data.X[covar_type])

        if covar_type == "full":
            for covar, precision in zip(bgmm.covariances_, bgmm.precisions_):
                assert_almost_equal(np.dot(covar, precision), np.eye(n_features))
        elif covar_type == "tied":
            assert_almost_equal(
                np.dot(bgmm.covariances_, bgmm.precisions_), np.eye(n_features)
            )
        elif covar_type == "diag":
            assert_almost_equal(
                bgmm.covariances_ * bgmm.precisions_,
                np.ones((n_components, n_features)),
            )
        else:
            assert_almost_equal(
                bgmm.covariances_ * bgmm.precisions_, np.ones(n_components)
            )


@ignore_warnings(category=ConvergenceWarning)
def test_invariant_translation():
    # We check here that adding a constant to the data shifts the means
    # accordingly while leaving the weights and covariances unchanged.
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng, scale=100)
    n_components = 2 * rand_data.n_components

    for prior_type in PRIOR_TYPE:
        for covar_type in COVARIANCE_TYPE:
            X = rand_data.X[covar_type]
            bgmm1 = BayesianGaussianMixture(
                weight_concentration_prior_type=prior_type,
                n_components=n_components,
                max_iter=100,
                random_state=0,
                tol=1e-3,
                reg_covar=0,
            ).fit(X)
            bgmm2 = BayesianGaussianMixture(
                weight_concentration_prior_type=prior_type,
                n_components=n_components,
                max_iter=100,
                random_state=0,
                tol=1e-3,
                reg_covar=0,
            ).fit(X + 100)

            assert_almost_equal(bgmm1.means_, bgmm2.means_ - 100)
            assert_almost_equal(bgmm1.weights_, bgmm2.weights_)
            assert_almost_equal(bgmm1.covariances_, bgmm2.covariances_)


@pytest.mark.filterwarnings("ignore:.*did not converge.*")
@pytest.mark.parametrize(
    "seed, max_iter, tol",
    [
        (0, 2, 1e-7),  # strict non-convergence
        (1, 2, 1e-1),  # loose non-convergence
        (3, 300, 1e-7),  # strict convergence
        (4, 300, 1e-1),  # loose convergence
    ],
)
def test_bayesian_mixture_fit_predict(seed, max_iter, tol):
    rng = np.random.RandomState(seed)
    rand_data = RandomData(rng, n_samples=50, scale=7)
    n_components = 2 * rand_data.n_components

    for covar_type in COVARIANCE_TYPE:
        bgmm1 = BayesianGaussianMixture(
            n_components=n_components,
            max_iter=max_iter,
            random_state=rng,
            tol=tol,
            reg_covar=0,
        )
        bgmm1.covariance_type = covar_type
        bgmm2 = copy.deepcopy(bgmm1)
        X = rand_data.X[covar_type]

        Y_pred1 = bgmm1.fit(X).predict(X)
        Y_pred2 = bgmm2.fit_predict(X)
        assert_array_equal(Y_pred1, Y_pred2)


def test_bayesian_mixture_fit_predict_n_init():
    # Check that fit_predict is equivalent to fit.predict, when n_init > 1
    X = np.random.RandomState(0).randn(50, 5)
    gm = BayesianGaussianMixture(n_components=5, n_init=10, random_state=0)
    y_pred1 = gm.fit_predict(X)
    y_pred2 = gm.predict(X)
    assert_array_equal(y_pred1, y_pred2)
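

# With n_init > 1 the estimator keeps the parameters of the best of the
# initializations, so fit_predict must return the labels from that same best
# run for the equivalence with a subsequent predict call to hold.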


def test_bayesian_mixture_predict_predict_proba():
    # this is the same test as test_gaussian_mixture_predict_predict_proba()
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng)
    for prior_type in PRIOR_TYPE:
        for covar_type in COVARIANCE_TYPE:
            X = rand_data.X[covar_type]
            Y = rand_data.Y
            bgmm = BayesianGaussianMixture(
                n_components=rand_data.n_components,
                random_state=rng,
                weight_concentration_prior_type=prior_type,
                covariance_type=covar_type,
            )

            # Check that an error is raised if predict is called before fit
            msg = (
                "This BayesianGaussianMixture instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this "
                "estimator."
            )
            with pytest.raises(NotFittedError, match=msg):
                bgmm.predict(X)

            bgmm.fit(X)
            Y_pred = bgmm.predict(X)
            Y_pred_proba = bgmm.predict_proba(X).argmax(axis=1)
            assert_array_equal(Y_pred, Y_pred_proba)
            assert adjusted_rand_score(Y, Y_pred) >= 0.95