_bayesian_mixture.py

  1. """Bayesian Gaussian Mixture Model."""
  2. # Author: Wei Xue <xuewei4d@gmail.com>
  3. # Thierry Guillemot <thierry.guillemot.work@gmail.com>
  4. # License: BSD 3 clause
  5. import math
  6. from numbers import Real
  7. import numpy as np
  8. from scipy.special import betaln, digamma, gammaln
  9. from ..utils import check_array
  10. from ..utils._param_validation import Interval, StrOptions
  11. from ._base import BaseMixture, _check_shape
  12. from ._gaussian_mixture import (
  13. _check_precision_matrix,
  14. _check_precision_positivity,
  15. _compute_log_det_cholesky,
  16. _compute_precision_cholesky,
  17. _estimate_gaussian_parameters,
  18. _estimate_log_gaussian_prob,
  19. )
  20. def _log_dirichlet_norm(dirichlet_concentration):
  21. """Compute the log of the Dirichlet distribution normalization term.
  22. Parameters
  23. ----------
  24. dirichlet_concentration : array-like of shape (n_samples,)
  25. The parameters values of the Dirichlet distribution.
  26. Returns
  27. -------
  28. log_dirichlet_norm : float
  29. The log normalization of the Dirichlet distribution.
  30. """
  31. return gammaln(np.sum(dirichlet_concentration)) - np.sum(
  32. gammaln(dirichlet_concentration)
  33. )
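
# The value above is the log of the inverse multivariate Beta function, i.e.
# the normalization constant of a Dirichlet density:
#     ln C(alpha) = ln Gamma(sum_k alpha_k) - sum_k ln Gamma(alpha_k).
# Quick illustrative check (a sketch, using the NumPy/SciPy imports above):
#     alpha = np.array([1.0, 2.0, 3.0])
#     _log_dirichlet_norm(alpha)   # equals gammaln(6.0) - gammaln(alpha).sum()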


def _log_wishart_norm(degrees_of_freedom, log_det_precisions_chol, n_features):
    """Compute the log of the Wishart distribution normalization term.

    Parameters
    ----------
    degrees_of_freedom : array-like of shape (n_components,)
        The number of degrees of freedom on the covariance Wishart
        distributions.

    log_det_precisions_chol : array-like of shape (n_components,)
        The log-determinant of the Cholesky decomposition of the precision
        matrix for each component.

    n_features : int
        The number of features.

    Returns
    -------
    log_wishart_norm : array-like of shape (n_components,)
        The log normalization of the Wishart distribution.
    """
    # To simplify the computation we have removed the np.log(np.pi) term
    return -(
        degrees_of_freedom * log_det_precisions_chol
        + degrees_of_freedom * n_features * 0.5 * math.log(2.0)
        + np.sum(
            gammaln(0.5 * (degrees_of_freedom - np.arange(n_features)[:, np.newaxis])),
            0,
        )
    )
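
# For reference: apart from the constant
# `n_features * (n_features - 1) / 4 * log(pi)` contributed by the
# multivariate gamma function (the dropped `np.log(np.pi)` term mentioned
# above), this is the log normalization constant ln B(W, nu) of the Wishart
# distribution (cf. Bishop, 2006, Appendix B), where
# `log_det_precisions_chol` is the log-determinant of the Cholesky factor,
# i.e. half the log-determinant of the corresponding precision-scale matrix.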


class BayesianGaussianMixture(BaseMixture):
    """Variational Bayesian estimation of a Gaussian mixture.

    This class allows inference of an approximate posterior distribution over
    the parameters of a Gaussian mixture distribution. The effective number of
    components can be inferred from the data.

    This class implements two types of prior for the weights distribution: a
    finite mixture model with Dirichlet distribution and an infinite mixture
    model with the Dirichlet Process. In practice, the Dirichlet Process
    inference algorithm is approximated and uses a truncated distribution with
    a fixed maximum number of components (called the Stick-breaking
    representation). The number of components actually used almost always
    depends on the data.

    .. versionadded:: 0.18

    Read more in the :ref:`User Guide <bgmm>`.

    Parameters
    ----------
    n_components : int, default=1
        The number of mixture components. Depending on the data and the value
        of the `weight_concentration_prior` the model can decide to not use
        all the components by setting some component `weights_` to values very
        close to zero. The number of effective components is therefore smaller
        than n_components.

    covariance_type : {'full', 'tied', 'diag', 'spherical'}, default='full'
        String describing the type of covariance parameters to use.
        Must be one of::

            'full' (each component has its own general covariance matrix),
            'tied' (all components share the same general covariance matrix),
            'diag' (each component has its own diagonal covariance matrix),
            'spherical' (each component has its own single variance).

    tol : float, default=1e-3
        The convergence threshold. EM iterations will stop when the
        lower bound average gain on the likelihood (of the training data with
        respect to the model) is below this threshold.

    reg_covar : float, default=1e-6
        Non-negative regularization added to the diagonal of covariance.
        Ensures that the covariance matrices are all positive.

    max_iter : int, default=100
        The number of EM iterations to perform.

    n_init : int, default=1
        The number of initializations to perform. The result with the highest
        lower bound value on the likelihood is kept.

    init_params : {'kmeans', 'k-means++', 'random', 'random_from_data'}, \
            default='kmeans'
        The method used to initialize the weights, the means and the
        covariances. String must be one of:

            'kmeans' : responsibilities are initialized using kmeans.
            'k-means++' : use the k-means++ method to initialize.
            'random' : responsibilities are initialized randomly.
            'random_from_data' : initial means are randomly selected data points.

        .. versionchanged:: v1.1
            `init_params` now accepts 'random_from_data' and 'k-means++' as
            initialization methods.

    weight_concentration_prior_type : {'dirichlet_process', 'dirichlet_distribution'}, \
            default='dirichlet_process'
        String describing the type of the weight concentration prior.

    weight_concentration_prior : float or None, default=None
        The dirichlet concentration of each component on the weight
        distribution (Dirichlet). This is commonly called gamma in the
        literature. A higher concentration puts more mass in the center and
        will lead to more components being active, while a lower concentration
        parameter will lead to more mass at the edge of the mixture weights
        simplex. The value of the parameter must be greater than 0.
        If it is None, it is set to ``1. / n_components``.

    mean_precision_prior : float or None, default=None
        The precision prior on the mean distribution (Gaussian).
        Controls the extent of where means can be placed. Larger
        values concentrate the cluster means around `mean_prior`.
        The value of the parameter must be greater than 0.
        If it is None, it is set to 1.

    mean_prior : array-like, shape (n_features,), default=None
        The prior on the mean distribution (Gaussian).
        If it is None, it is set to the mean of X.

    degrees_of_freedom_prior : float or None, default=None
        The prior of the number of degrees of freedom on the covariance
        distributions (Wishart). If it is None, it is set to `n_features`.

    covariance_prior : float or array-like, default=None
        The prior on the covariance distribution (Wishart).
        If it is None, the empirical covariance prior is initialized using the
        covariance of X. The shape depends on `covariance_type`::

            (n_features, n_features) if 'full',
            (n_features, n_features) if 'tied',
            (n_features,)            if 'diag',
            float                    if 'spherical'

    random_state : int, RandomState instance or None, default=None
        Controls the random seed given to the method chosen to initialize the
        parameters (see `init_params`).
        In addition, it controls the generation of random samples from the
        fitted distribution (see the method `sample`).
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    warm_start : bool, default=False
        If `warm_start` is True, the solution of the last fitting is used as
        initialization for the next call of fit(). This can speed up
        convergence when fit is called several times on similar problems.
        See :term:`the Glossary <warm_start>`.

    verbose : int, default=0
        Enable verbose output. If 1 then it prints the current
        initialization and each iteration step. If greater than 1 then
        it prints also the log probability and the time needed
        for each step.

    verbose_interval : int, default=10
        Number of iterations done before the next print.

    Attributes
    ----------
    weights_ : array-like of shape (n_components,)
        The weights of each mixture component.

    means_ : array-like of shape (n_components, n_features)
        The mean of each mixture component.

    covariances_ : array-like
        The covariance of each mixture component.
        The shape depends on `covariance_type`::

            (n_components,)                        if 'spherical',
            (n_features, n_features)               if 'tied',
            (n_components, n_features)             if 'diag',
            (n_components, n_features, n_features) if 'full'

    precisions_ : array-like
        The precision matrices for each component in the mixture. A precision
        matrix is the inverse of a covariance matrix. A covariance matrix is
        symmetric positive definite so the mixture of Gaussian can be
        equivalently parameterized by the precision matrices. Storing the
        precision matrices instead of the covariance matrices makes it more
        efficient to compute the log-likelihood of new samples at test time.
        The shape depends on ``covariance_type``::

            (n_components,)                        if 'spherical',
            (n_features, n_features)               if 'tied',
            (n_components, n_features)             if 'diag',
            (n_components, n_features, n_features) if 'full'

    precisions_cholesky_ : array-like
        The cholesky decomposition of the precision matrices of each mixture
        component. A precision matrix is the inverse of a covariance matrix.
        A covariance matrix is symmetric positive definite so the mixture of
        Gaussian can be equivalently parameterized by the precision matrices.
        Storing the precision matrices instead of the covariance matrices makes
        it more efficient to compute the log-likelihood of new samples at test
        time. The shape depends on ``covariance_type``::

            (n_components,)                        if 'spherical',
            (n_features, n_features)               if 'tied',
            (n_components, n_features)             if 'diag',
            (n_components, n_features, n_features) if 'full'

    converged_ : bool
        True when convergence was reached in fit(), False otherwise.

    n_iter_ : int
        Number of steps used by the best fit of inference to reach the
        convergence.

    lower_bound_ : float
        Lower bound value on the model evidence (of the training data) of the
        best fit of inference.

    weight_concentration_prior_ : tuple or float
        The dirichlet concentration of each component on the weight
        distribution (Dirichlet). The type depends on
        ``weight_concentration_prior_type``::

            (float, float) if 'dirichlet_process' (Beta parameters),
            float          if 'dirichlet_distribution' (Dirichlet parameters).

        A higher concentration puts more mass in the center and will lead to
        more components being active, while a lower concentration parameter
        will lead to more mass at the edge of the simplex.

    weight_concentration_ : array-like of shape (n_components,)
        The dirichlet concentration of each component on the weight
        distribution (Dirichlet).

    mean_precision_prior_ : float
        The precision prior on the mean distribution (Gaussian).
        Controls the extent of where means can be placed.
        Larger values concentrate the cluster means around `mean_prior`.
        If mean_precision_prior is set to None, `mean_precision_prior_` is set
        to 1.

    mean_precision_ : array-like of shape (n_components,)
        The precision of each component on the mean distribution (Gaussian).

    mean_prior_ : array-like of shape (n_features,)
        The prior on the mean distribution (Gaussian).

    degrees_of_freedom_prior_ : float
        The prior of the number of degrees of freedom on the covariance
        distributions (Wishart).

    degrees_of_freedom_ : array-like of shape (n_components,)
        The number of degrees of freedom of each component in the model.

    covariance_prior_ : float or array-like
        The prior on the covariance distribution (Wishart).
        The shape depends on `covariance_type`::

            (n_features, n_features) if 'full',
            (n_features, n_features) if 'tied',
            (n_features,)            if 'diag',
            float                    if 'spherical'

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    See Also
    --------
    GaussianMixture : Finite Gaussian mixture fit with EM.

    References
    ----------
    .. [1] `Bishop, Christopher M. (2006). "Pattern recognition and machine
       learning". Vol. 4 No. 4. New York: Springer.
       <https://www.springer.com/kr/book/9780387310732>`_

    .. [2] `Hagai Attias. (2000). "A Variational Bayesian Framework for
       Graphical Models". In Advances in Neural Information Processing
       Systems 12.
       <https://citeseerx.ist.psu.edu/doc_view/pid/ee844fd96db7041a9681b5a18bff008912052c7e>`_

    .. [3] `Blei, David M. and Michael I. Jordan. (2006). "Variational
       inference for Dirichlet process mixtures". Bayesian analysis 1.1
       <https://www.cs.princeton.edu/courses/archive/fall11/cos597C/reading/BleiJordan2005.pdf>`_

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.mixture import BayesianGaussianMixture
    >>> X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [12, 4], [10, 7]])
    >>> bgm = BayesianGaussianMixture(n_components=2, random_state=42).fit(X)
    >>> bgm.means_
    array([[2.49... , 2.29...],
           [8.45..., 4.52... ]])
    >>> bgm.predict([[0, 0], [9, 3]])
    array([0, 1])
    """

    _parameter_constraints: dict = {
        **BaseMixture._parameter_constraints,
        "covariance_type": [StrOptions({"spherical", "tied", "diag", "full"})],
        "weight_concentration_prior_type": [
            StrOptions({"dirichlet_process", "dirichlet_distribution"})
        ],
        "weight_concentration_prior": [
            None,
            Interval(Real, 0.0, None, closed="neither"),
        ],
        "mean_precision_prior": [None, Interval(Real, 0.0, None, closed="neither")],
        "mean_prior": [None, "array-like"],
        "degrees_of_freedom_prior": [None, Interval(Real, 0.0, None, closed="neither")],
        "covariance_prior": [
            None,
            "array-like",
            Interval(Real, 0.0, None, closed="neither"),
        ],
    }

    def __init__(
        self,
        *,
        n_components=1,
        covariance_type="full",
        tol=1e-3,
        reg_covar=1e-6,
        max_iter=100,
        n_init=1,
        init_params="kmeans",
        weight_concentration_prior_type="dirichlet_process",
        weight_concentration_prior=None,
        mean_precision_prior=None,
        mean_prior=None,
        degrees_of_freedom_prior=None,
        covariance_prior=None,
        random_state=None,
        warm_start=False,
        verbose=0,
        verbose_interval=10,
    ):
        super().__init__(
            n_components=n_components,
            tol=tol,
            reg_covar=reg_covar,
            max_iter=max_iter,
            n_init=n_init,
            init_params=init_params,
            random_state=random_state,
            warm_start=warm_start,
            verbose=verbose,
            verbose_interval=verbose_interval,
        )

        self.covariance_type = covariance_type
        self.weight_concentration_prior_type = weight_concentration_prior_type
        self.weight_concentration_prior = weight_concentration_prior
        self.mean_precision_prior = mean_precision_prior
        self.mean_prior = mean_prior
        self.degrees_of_freedom_prior = degrees_of_freedom_prior
        self.covariance_prior = covariance_prior

    def _check_parameters(self, X):
        """Check that the parameters are well defined.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        self._check_weights_parameters()
        self._check_means_parameters(X)
        self._check_precision_parameters(X)
        self._checkcovariance_prior_parameter(X)

    def _check_weights_parameters(self):
        """Check the parameter of the Dirichlet distribution."""
        if self.weight_concentration_prior is None:
            self.weight_concentration_prior_ = 1.0 / self.n_components
        else:
            self.weight_concentration_prior_ = self.weight_concentration_prior

    def _check_means_parameters(self, X):
        """Check the parameters of the Gaussian distribution.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        _, n_features = X.shape

        if self.mean_precision_prior is None:
            self.mean_precision_prior_ = 1.0
        else:
            self.mean_precision_prior_ = self.mean_precision_prior

        if self.mean_prior is None:
            self.mean_prior_ = X.mean(axis=0)
        else:
            self.mean_prior_ = check_array(
                self.mean_prior, dtype=[np.float64, np.float32], ensure_2d=False
            )
            _check_shape(self.mean_prior_, (n_features,), "means")

    def _check_precision_parameters(self, X):
        """Check the prior parameters of the precision distribution.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        _, n_features = X.shape

        if self.degrees_of_freedom_prior is None:
            self.degrees_of_freedom_prior_ = n_features
        elif self.degrees_of_freedom_prior > n_features - 1.0:
            self.degrees_of_freedom_prior_ = self.degrees_of_freedom_prior
        else:
            raise ValueError(
                "The parameter 'degrees_of_freedom_prior' "
                "should be greater than %d, but got %.3f."
                % (n_features - 1, self.degrees_of_freedom_prior)
            )

    def _checkcovariance_prior_parameter(self, X):
        """Check the `covariance_prior_`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        _, n_features = X.shape

        if self.covariance_prior is None:
            self.covariance_prior_ = {
                "full": np.atleast_2d(np.cov(X.T)),
                "tied": np.atleast_2d(np.cov(X.T)),
                "diag": np.var(X, axis=0, ddof=1),
                "spherical": np.var(X, axis=0, ddof=1).mean(),
            }[self.covariance_type]
        elif self.covariance_type in ["full", "tied"]:
            self.covariance_prior_ = check_array(
                self.covariance_prior, dtype=[np.float64, np.float32], ensure_2d=False
            )
            _check_shape(
                self.covariance_prior_,
                (n_features, n_features),
                "%s covariance_prior" % self.covariance_type,
            )
            _check_precision_matrix(self.covariance_prior_, self.covariance_type)
        elif self.covariance_type == "diag":
            self.covariance_prior_ = check_array(
                self.covariance_prior, dtype=[np.float64, np.float32], ensure_2d=False
            )
            _check_shape(
                self.covariance_prior_,
                (n_features,),
                "%s covariance_prior" % self.covariance_type,
            )
            _check_precision_positivity(self.covariance_prior_, self.covariance_type)
        # spherical case
        else:
            self.covariance_prior_ = self.covariance_prior

    def _initialize(self, X, resp):
        """Initialization of the mixture parameters.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        resp : array-like of shape (n_samples, n_components)
        """
        nk, xk, sk = _estimate_gaussian_parameters(
            X, resp, self.reg_covar, self.covariance_type
        )

        self._estimate_weights(nk)
        self._estimate_means(nk, xk)
        self._estimate_precisions(nk, xk, sk)

    def _estimate_weights(self, nk):
        """Estimate the parameters of the Dirichlet distribution.

        Parameters
        ----------
        nk : array-like of shape (n_components,)
        """
        if self.weight_concentration_prior_type == "dirichlet_process":
            # For dirichlet process weight_concentration will be a tuple
            # containing the two parameters of the beta distribution
            self.weight_concentration_ = (
                1.0 + nk,
                (
                    self.weight_concentration_prior_
                    + np.hstack((np.cumsum(nk[::-1])[-2::-1], 0))
                ),
            )
        else:
            # case Variational Gaussian mixture with dirichlet distribution
            self.weight_concentration_ = self.weight_concentration_prior_ + nk
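
    # In the Dirichlet process case, the two arrays stored above are the
    # posterior Beta parameters of the stick-breaking representation
    # (cf. Blei & Jordan, 2006):
    #     gamma_{k,1} = 1 + N_k
    #     gamma_{k,2} = weight_concentration_prior_ + sum_{j > k} N_j
    # where `np.cumsum(nk[::-1])[-2::-1]` evaluates the right-tail sums
    # sum_{j > k} N_j (and the appended 0 handles the last component).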

    def _estimate_means(self, nk, xk):
        """Estimate the parameters of the Gaussian distribution.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)
        """
        self.mean_precision_ = self.mean_precision_prior_ + nk
        self.means_ = (
            self.mean_precision_prior_ * self.mean_prior_ + nk[:, np.newaxis] * xk
        ) / self.mean_precision_[:, np.newaxis]
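
    # For reference, these are the standard variational updates of the
    # Gaussian factor on the means (cf. Bishop, 2006, Eqs. 10.60-10.61):
    #     beta_k = beta_0 + N_k
    #     m_k = (beta_0 * m_0 + N_k * xbar_k) / beta_k
    # with beta_0 = mean_precision_prior_ and m_0 = mean_prior_.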

    def _estimate_precisions(self, nk, xk, sk):
        """Estimate the precisions parameters of the precision distribution.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)

        sk : array-like
            The shape depends on `covariance_type`:
                'full' : (n_components, n_features, n_features)
                'tied' : (n_features, n_features)
                'diag' : (n_components, n_features)
                'spherical' : (n_components,)
        """
        {
            "full": self._estimate_wishart_full,
            "tied": self._estimate_wishart_tied,
            "diag": self._estimate_wishart_diag,
            "spherical": self._estimate_wishart_spherical,
        }[self.covariance_type](nk, xk, sk)

        self.precisions_cholesky_ = _compute_precision_cholesky(
            self.covariances_, self.covariance_type
        )
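
    # Note on normalization: every `_estimate_wishart_*` helper below divides
    # the updated scale matrices by the degrees of freedom, so `covariances_`
    # stores W_k^{-1} / nu_k rather than W_k^{-1} itself.  Consequently
    # `precisions_` and `precisions_cholesky_` correspond to the expected
    # precision E_q[Lambda_k] = nu_k * W_k, and `_estimate_log_prob` /
    # `_compute_lower_bound` subtract `0.5 * n_features * log(nu_k)` terms to
    # compensate for this rescaling.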

    def _estimate_wishart_full(self, nk, xk, sk):
        """Estimate the full Wishart distribution parameters.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)

        sk : array-like of shape (n_components, n_features, n_features)
        """
        _, n_features = xk.shape

        # Warning : in some editions of the Bishop book, there is a typo in
        # formula 10.63; `degrees_of_freedom_k = degrees_of_freedom_0 + Nk`
        # is the correct formula
        self.degrees_of_freedom_ = self.degrees_of_freedom_prior_ + nk

        self.covariances_ = np.empty((self.n_components, n_features, n_features))

        for k in range(self.n_components):
            diff = xk[k] - self.mean_prior_
            self.covariances_[k] = (
                self.covariance_prior_
                + nk[k] * sk[k]
                + nk[k]
                * self.mean_precision_prior_
                / self.mean_precision_[k]
                * np.outer(diff, diff)
            )

        # Contrary to the original Bishop book, we normalize the covariances
        self.covariances_ /= self.degrees_of_freedom_[:, np.newaxis, np.newaxis]
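
    # For reference, the unnormalized update above follows Bishop (2006),
    # Eq. 10.62:
    #     W_k^{-1} = W_0^{-1} + N_k * S_k
    #                + (beta_0 * N_k / (beta_0 + N_k)) * (xbar_k - m_0)(xbar_k - m_0)^T
    # with `covariance_prior_` playing the role of W_0^{-1}; the final
    # division by nu_k stores the normalized covariance instead of W_k^{-1}.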

    def _estimate_wishart_tied(self, nk, xk, sk):
        """Estimate the tied Wishart distribution parameters.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)

        sk : array-like of shape (n_features, n_features)
        """
        _, n_features = xk.shape

        # Warning : in some editions of the Bishop book, there is a typo in
        # formula 10.63; `degrees_of_freedom_k = degrees_of_freedom_0 + Nk`
        # is the correct formula
        self.degrees_of_freedom_ = (
            self.degrees_of_freedom_prior_ + nk.sum() / self.n_components
        )

        diff = xk - self.mean_prior_
        self.covariances_ = (
            self.covariance_prior_
            + sk * nk.sum() / self.n_components
            + self.mean_precision_prior_
            / self.n_components
            * np.dot((nk / self.mean_precision_) * diff.T, diff)
        )

        # Contrary to the original Bishop book, we normalize the covariances
        self.covariances_ /= self.degrees_of_freedom_

    def _estimate_wishart_diag(self, nk, xk, sk):
        """Estimate the diag Wishart distribution parameters.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)

        sk : array-like of shape (n_components, n_features)
        """
        _, n_features = xk.shape

        # Warning : in some editions of the Bishop book, there is a typo in
        # formula 10.63; `degrees_of_freedom_k = degrees_of_freedom_0 + Nk`
        # is the correct formula
        self.degrees_of_freedom_ = self.degrees_of_freedom_prior_ + nk

        diff = xk - self.mean_prior_
        self.covariances_ = self.covariance_prior_ + nk[:, np.newaxis] * (
            sk
            + (self.mean_precision_prior_ / self.mean_precision_)[:, np.newaxis]
            * np.square(diff)
        )

        # Contrary to the original Bishop book, we normalize the covariances
        self.covariances_ /= self.degrees_of_freedom_[:, np.newaxis]

    def _estimate_wishart_spherical(self, nk, xk, sk):
        """Estimate the spherical Wishart distribution parameters.

        Parameters
        ----------
        nk : array-like of shape (n_components,)

        xk : array-like of shape (n_components, n_features)

        sk : array-like of shape (n_components,)
        """
        _, n_features = xk.shape

        # Warning : in some editions of the Bishop book, there is a typo in
        # formula 10.63; `degrees_of_freedom_k = degrees_of_freedom_0 + Nk`
        # is the correct formula
        self.degrees_of_freedom_ = self.degrees_of_freedom_prior_ + nk

        diff = xk - self.mean_prior_
        self.covariances_ = self.covariance_prior_ + nk * (
            sk
            + self.mean_precision_prior_
            / self.mean_precision_
            * np.mean(np.square(diff), 1)
        )

        # Contrary to the original Bishop book, we normalize the covariances
        self.covariances_ /= self.degrees_of_freedom_

    def _m_step(self, X, log_resp):
        """M step.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        log_resp : array-like of shape (n_samples, n_components)
            Logarithm of the posterior probabilities (or responsibilities) of
            each sample in X.
        """
        n_samples, _ = X.shape

        nk, xk, sk = _estimate_gaussian_parameters(
            X, np.exp(log_resp), self.reg_covar, self.covariance_type
        )
        self._estimate_weights(nk)
        self._estimate_means(nk, xk)
        self._estimate_precisions(nk, xk, sk)

    def _estimate_log_weights(self):
        if self.weight_concentration_prior_type == "dirichlet_process":
            digamma_sum = digamma(
                self.weight_concentration_[0] + self.weight_concentration_[1]
            )
            digamma_a = digamma(self.weight_concentration_[0])
            digamma_b = digamma(self.weight_concentration_[1])
            return (
                digamma_a
                - digamma_sum
                + np.hstack((0, np.cumsum(digamma_b - digamma_sum)[:-1]))
            )
        else:
            # case Variational Gaussian mixture with dirichlet distribution
            return digamma(self.weight_concentration_) - digamma(
                np.sum(self.weight_concentration_)
            )
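
    # In the Dirichlet process case, the value returned above is the expected
    # log mixture weight under the stick-breaking construction:
    #     E[ln pi_k] = E[ln v_k] + sum_{j < k} E[ln(1 - v_j)],
    # with E[ln v_k] = psi(a_k) - psi(a_k + b_k) and
    #      E[ln(1 - v_k)] = psi(b_k) - psi(a_k + b_k),
    # where (a_k, b_k) are the Beta parameters in `weight_concentration_` and
    # psi is the digamma function.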

    def _estimate_log_prob(self, X):
        _, n_features = X.shape
        # We remove `n_features * np.log(self.degrees_of_freedom_)` because
        # the precision matrix is normalized
        log_gauss = _estimate_log_gaussian_prob(
            X, self.means_, self.precisions_cholesky_, self.covariance_type
        ) - 0.5 * n_features * np.log(self.degrees_of_freedom_)

        log_lambda = n_features * np.log(2.0) + np.sum(
            digamma(
                0.5
                * (self.degrees_of_freedom_ - np.arange(0, n_features)[:, np.newaxis])
            ),
            0,
        )

        return log_gauss + 0.5 * (log_lambda - n_features / self.mean_precision_)
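
    # The extra terms above implement the expectations needed by the
    # variational E-step (cf. Bishop, 2006, Eqs. 10.64-10.65): `log_lambda`
    # collects the digamma part of E[ln|Lambda_k|] (the ln|W_k| part is
    # already carried by `log_gauss` through the normalized Cholesky factors),
    # and `n_features / self.mean_precision_` is the D / beta_k contribution
    # of the uncertainty on the means to E[(x - mu_k)^T Lambda_k (x - mu_k)].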

    def _compute_lower_bound(self, log_resp, log_prob_norm):
        """Estimate the lower bound of the model.

        The lower bound on the likelihood (of the training data with respect to
        the model) is used to detect the convergence and has to increase at
        each iteration.

        Parameters
        ----------
        log_resp : array, shape (n_samples, n_components)
            Logarithm of the posterior probabilities (or responsibilities) of
            each sample in X.

        log_prob_norm : float
            Logarithm of the probability of each sample in X.

        Returns
        -------
        lower_bound : float
        """
        # Contrary to the original formula, we have done some simplification
        # and removed all the constant terms.
        (n_features,) = self.mean_prior_.shape

        # We removed `.5 * n_features * np.log(self.degrees_of_freedom_)`
        # because the precision matrix is normalized.
        log_det_precisions_chol = _compute_log_det_cholesky(
            self.precisions_cholesky_, self.covariance_type, n_features
        ) - 0.5 * n_features * np.log(self.degrees_of_freedom_)

        if self.covariance_type == "tied":
            log_wishart = self.n_components * np.float64(
                _log_wishart_norm(
                    self.degrees_of_freedom_, log_det_precisions_chol, n_features
                )
            )
        else:
            log_wishart = np.sum(
                _log_wishart_norm(
                    self.degrees_of_freedom_, log_det_precisions_chol, n_features
                )
            )

        if self.weight_concentration_prior_type == "dirichlet_process":
            log_norm_weight = -np.sum(
                betaln(self.weight_concentration_[0], self.weight_concentration_[1])
            )
        else:
            log_norm_weight = _log_dirichlet_norm(self.weight_concentration_)

        return (
            -np.sum(np.exp(log_resp) * log_resp)
            - log_wishart
            - log_norm_weight
            - 0.5 * n_features * np.sum(np.log(self.mean_precision_))
        )
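
    # After dropping all constant terms, the value returned above is
    #     H[q(Z)] - sum_k ln B(W_k, nu_k) - ln(weight normalization)
    #     - (n_features / 2) * sum_k ln(beta_k),
    # i.e. the entropy of the responsibilities minus the log normalization
    # constants of the variational Wishart, weight (Beta or Dirichlet) and
    # Gaussian-mean factors.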

    def _get_parameters(self):
        return (
            self.weight_concentration_,
            self.mean_precision_,
            self.means_,
            self.degrees_of_freedom_,
            self.covariances_,
            self.precisions_cholesky_,
        )

    def _set_parameters(self, params):
        (
            self.weight_concentration_,
            self.mean_precision_,
            self.means_,
            self.degrees_of_freedom_,
            self.covariances_,
            self.precisions_cholesky_,
        ) = params

        # Weights computation
        if self.weight_concentration_prior_type == "dirichlet_process":
            weight_dirichlet_sum = (
                self.weight_concentration_[0] + self.weight_concentration_[1]
            )
            tmp = self.weight_concentration_[1] / weight_dirichlet_sum
            self.weights_ = (
                self.weight_concentration_[0]
                / weight_dirichlet_sum
                * np.hstack((1, np.cumprod(tmp[:-1])))
            )
            self.weights_ /= np.sum(self.weights_)
        else:
            self.weights_ = self.weight_concentration_ / np.sum(
                self.weight_concentration_
            )

        # Precisions matrices computation
        if self.covariance_type == "full":
            self.precisions_ = np.array(
                [
                    np.dot(prec_chol, prec_chol.T)
                    for prec_chol in self.precisions_cholesky_
                ]
            )
        elif self.covariance_type == "tied":
            self.precisions_ = np.dot(
                self.precisions_cholesky_, self.precisions_cholesky_.T
            )
        else:
            self.precisions_ = self.precisions_cholesky_**2
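

# Illustrative usage sketch (kept as a comment so importing this module stays
# side-effect free): with a Dirichlet process prior the model can be given
# more components than the data supports and will assign negligible weight to
# the unused ones, e.g.:
#
#     import numpy as np
#     from sklearn.mixture import BayesianGaussianMixture
#
#     rng = np.random.RandomState(0)
#     X = np.vstack([rng.normal(0, 1, (200, 2)), rng.normal(6, 1, (200, 2))])
#     bgm = BayesianGaussianMixture(
#         n_components=10,
#         weight_concentration_prior_type="dirichlet_process",
#         weight_concentration_prior=1e-2,
#         random_state=0,
#     ).fit(X)
#     print(np.round(bgm.weights_, 3))  # most of the 10 weights end up near 0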