  1. """
  2. =============================================================
  3. Online Latent Dirichlet Allocation with variational inference
  4. =============================================================
  5. This implementation is modified from Matthew D. Hoffman's onlineldavb code
  6. Link: https://github.com/blei-lab/onlineldavb
  7. """
  8. # Author: Chyi-Kwei Yau
  9. # Author: Matthew D. Hoffman (original onlineldavb implementation)
  10. from numbers import Integral, Real
  11. import numpy as np
  12. import scipy.sparse as sp
  13. from joblib import effective_n_jobs
  14. from scipy.special import gammaln, logsumexp
  15. from ..base import (
  16. BaseEstimator,
  17. ClassNamePrefixFeaturesOutMixin,
  18. TransformerMixin,
  19. _fit_context,
  20. )
  21. from ..utils import check_random_state, gen_batches, gen_even_slices
  22. from ..utils._param_validation import Interval, StrOptions
  23. from ..utils.parallel import Parallel, delayed
  24. from ..utils.validation import check_is_fitted, check_non_negative
  25. from ._online_lda_fast import (
  26. _dirichlet_expectation_1d as cy_dirichlet_expectation_1d,
  27. )
  28. from ._online_lda_fast import (
  29. _dirichlet_expectation_2d,
  30. )
  31. from ._online_lda_fast import (
  32. mean_change as cy_mean_change,
  33. )
  34. EPS = np.finfo(float).eps


def _update_doc_distribution(
    X,
    exp_topic_word_distr,
    doc_topic_prior,
    max_doc_update_iter,
    mean_change_tol,
    cal_sstats,
    random_state,
):
    """E-step: update document-topic distribution.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Document word matrix.

    exp_topic_word_distr : ndarray of shape (n_topics, n_features)
        Exponential value of expectation of log topic word distribution.
        In the literature, this is `exp(E[log(beta)])`.

    doc_topic_prior : float
        Prior of document topic distribution `theta`.

    max_doc_update_iter : int
        Max number of iterations for updating document topic distribution in
        the E-step.

    mean_change_tol : float
        Stopping tolerance for updating document topic distribution in E-step.

    cal_sstats : bool
        Parameter that indicates whether to calculate sufficient statistics
        or not. Set `cal_sstats` to `True` when we need to run the M-step.

    random_state : RandomState instance or None
        Parameter that indicates how to initialize the document topic
        distribution. Setting `random_state` to None initializes the document
        topic distribution to a constant number.

    Returns
    -------
    (doc_topic_distr, suff_stats) :
        `doc_topic_distr` is the unnormalized topic distribution for each
        document. In the literature, this is `gamma`. We can calculate
        `E[log(theta)]` from it.
        `suff_stats` is the expected sufficient statistics for the M-step.
        When `cal_sstats == False`, this will be None.
    """
    is_sparse_x = sp.issparse(X)
    n_samples, n_features = X.shape
    n_topics = exp_topic_word_distr.shape[0]

    if random_state:
        doc_topic_distr = random_state.gamma(100.0, 0.01, (n_samples, n_topics)).astype(
            X.dtype, copy=False
        )
    else:
        doc_topic_distr = np.ones((n_samples, n_topics), dtype=X.dtype)

    # In the literature, this is `exp(E[log(theta)])`
    exp_doc_topic = np.exp(_dirichlet_expectation_2d(doc_topic_distr))

    # Expected sufficient statistics for the M-step
    # (only allocated when `cal_sstats` is True).
    suff_stats = (
        np.zeros(exp_topic_word_distr.shape, dtype=X.dtype) if cal_sstats else None
    )

    if is_sparse_x:
        X_data = X.data
        X_indices = X.indices
        X_indptr = X.indptr

    # These cython functions are called in a nested loop on usually very small arrays
    # (length=n_topics). In that case, finding the appropriate signature of the
    # fused-typed function can be more costly than its execution, hence the dispatch
    # is done outside of the loop.
    ctype = "float" if X.dtype == np.float32 else "double"
    mean_change = cy_mean_change[ctype]
    dirichlet_expectation_1d = cy_dirichlet_expectation_1d[ctype]
    eps = np.finfo(X.dtype).eps

    for idx_d in range(n_samples):
        if is_sparse_x:
            ids = X_indices[X_indptr[idx_d] : X_indptr[idx_d + 1]]
            cnts = X_data[X_indptr[idx_d] : X_indptr[idx_d + 1]]
        else:
            ids = np.nonzero(X[idx_d, :])[0]
            cnts = X[idx_d, ids]

        doc_topic_d = doc_topic_distr[idx_d, :]
        # The next one is a copy, since the inner loop overwrites it.
        exp_doc_topic_d = exp_doc_topic[idx_d, :].copy()
        exp_topic_word_d = exp_topic_word_distr[:, ids]

        # Iterate between `doc_topic_d` and `norm_phi` until convergence
        for _ in range(0, max_doc_update_iter):
            last_d = doc_topic_d

            # The optimal phi_{dwk} is proportional to
            # exp(E[log(theta_{dk})]) * exp(E[log(beta_{dw})]).
            norm_phi = np.dot(exp_doc_topic_d, exp_topic_word_d) + eps

            doc_topic_d = exp_doc_topic_d * np.dot(cnts / norm_phi, exp_topic_word_d.T)
            # Note: adds doc_topic_prior to doc_topic_d, in-place.
            dirichlet_expectation_1d(doc_topic_d, doc_topic_prior, exp_doc_topic_d)

            if mean_change(last_d, doc_topic_d) < mean_change_tol:
                break
        doc_topic_distr[idx_d, :] = doc_topic_d

        # Contribution of document d to the expected sufficient
        # statistics for the M step.
        if cal_sstats:
            norm_phi = np.dot(exp_doc_topic_d, exp_topic_word_d) + eps
            suff_stats[:, ids] += np.outer(exp_doc_topic_d, cnts / norm_phi)

    return (doc_topic_distr, suff_stats)
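

# For reference, a minimal sketch of the per-document fixed-point update
# performed by `_update_doc_distribution` above (illustrative names only;
# `alpha` stands for `doc_topic_prior`, `exp_Elogbeta` for the columns of
# `exp_topic_word_distr` selected by `ids`, and `psi` for the digamma
# function, `scipy.special.psi`):
#
#     norm_phi = exp_Elogtheta_d @ exp_Elogbeta + eps
#     gamma_d = alpha + exp_Elogtheta_d * ((cnts / norm_phi) @ exp_Elogbeta.T)
#     exp_Elogtheta_d = np.exp(psi(gamma_d) - psi(gamma_d.sum()))
#
# iterated until the mean absolute change of `gamma_d` falls below
# `mean_change_tol`.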


class LatentDirichletAllocation(
    ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator
):
    """Latent Dirichlet Allocation with online variational Bayes algorithm.

    The implementation is based on [1]_ and [2]_.

    .. versionadded:: 0.17

    Read more in the :ref:`User Guide <LatentDirichletAllocation>`.

    Parameters
    ----------
    n_components : int, default=10
        Number of topics.

        .. versionchanged:: 0.19
            ``n_topics`` was renamed to ``n_components``

    doc_topic_prior : float, default=None
        Prior of document topic distribution `theta`. If the value is None,
        defaults to `1 / n_components`.
        In [1]_, this is called `alpha`.

    topic_word_prior : float, default=None
        Prior of topic word distribution `beta`. If the value is None, defaults
        to `1 / n_components`.
        In [1]_, this is called `eta`.

    learning_method : {'batch', 'online'}, default='batch'
        Method used to update `components_`. Only used in the :meth:`fit`
        method. In general, if the data size is large, the online update will
        be much faster than the batch update.

        Valid options::

            'batch': Batch variational Bayes method. Use all training data in
                each EM update.
                Old `components_` will be overwritten in each iteration.
            'online': Online variational Bayes method. In each EM update, use
                mini-batch of training data to update the ``components_``
                variable incrementally. The learning rate is controlled by the
                ``learning_decay`` and the ``learning_offset`` parameters.

        .. versionchanged:: 0.20
            The default learning method is now ``"batch"``.

    learning_decay : float, default=0.7
        A parameter that controls the learning rate in the online learning
        method. The value should be set between (0.5, 1.0] to guarantee
        asymptotic convergence. When the value is 0.0 and batch_size is
        ``n_samples``, the update method is the same as batch learning. In the
        literature, this is called kappa.

    learning_offset : float, default=10.0
        A (positive) parameter that downweights early iterations in online
        learning. It should be greater than 1.0. In the literature, this is
        called tau_0.

    max_iter : int, default=10
        The maximum number of passes over the training data (aka epochs).
        It only impacts the behavior in the :meth:`fit` method, and not the
        :meth:`partial_fit` method.

    batch_size : int, default=128
        Number of documents to use in each EM iteration. Only used in online
        learning.

    evaluate_every : int, default=-1
        How often to evaluate perplexity. Only used in the `fit` method.
        Set it to 0 or a negative number to not evaluate perplexity during
        training at all. Evaluating perplexity can help you check convergence
        during training, but it will also increase total training time.
        Evaluating perplexity in every iteration might increase training time
        up to two-fold.

    total_samples : int, default=1e6
        Total number of documents. Only used in the :meth:`partial_fit` method.

    perp_tol : float, default=1e-1
        Perplexity tolerance in batch learning. Only used when
        ``evaluate_every`` is greater than 0.

    mean_change_tol : float, default=1e-3
        Stopping tolerance for updating document topic distribution in E-step.

    max_doc_update_iter : int, default=100
        Max number of iterations for updating document topic distribution in
        the E-step.

    n_jobs : int, default=None
        The number of jobs to use in the E-step.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    verbose : int, default=0
        Verbosity level.

    random_state : int, RandomState instance or None, default=None
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    components_ : ndarray of shape (n_components, n_features)
        Variational parameters for topic word distribution. Since the complete
        conditional for topic word distribution is a Dirichlet,
        ``components_[i, j]`` can be viewed as pseudocount that represents the
        number of times word `j` was assigned to topic `i`.
        It can also be viewed as distribution over the words for each topic
        after normalization:
        ``model.components_ / model.components_.sum(axis=1)[:, np.newaxis]``.

    exp_dirichlet_component_ : ndarray of shape (n_components, n_features)
        Exponential value of expectation of log topic word distribution.
        In the literature, this is `exp(E[log(beta)])`.

    n_batch_iter_ : int
        Number of iterations of the EM step.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        Number of passes over the dataset.

    bound_ : float
        Final perplexity score on training set.

    doc_topic_prior_ : float
        Prior of document topic distribution `theta`. If the value is None,
        it is `1 / n_components`.

    random_state_ : RandomState instance
        RandomState instance that is generated either from a seed, the random
        number generator or by `np.random`.

    topic_word_prior_ : float
        Prior of topic word distribution `beta`. If the value is None, it is
        `1 / n_components`.

    See Also
    --------
    sklearn.discriminant_analysis.LinearDiscriminantAnalysis:
        A classifier with a linear decision boundary, generated by fitting
        class conditional densities to the data and using Bayes' rule.

    References
    ----------
    .. [1] "Online Learning for Latent Dirichlet Allocation", Matthew D.
        Hoffman, David M. Blei, Francis Bach, 2010
        https://github.com/blei-lab/onlineldavb

    .. [2] "Stochastic Variational Inference", Matthew D. Hoffman,
        David M. Blei, Chong Wang, John Paisley, 2013

    Examples
    --------
    >>> from sklearn.decomposition import LatentDirichletAllocation
    >>> from sklearn.datasets import make_multilabel_classification
    >>> # This produces a feature matrix of token counts, similar to what
    >>> # CountVectorizer would produce on text.
    >>> X, _ = make_multilabel_classification(random_state=0)
    >>> lda = LatentDirichletAllocation(n_components=5,
    ...     random_state=0)
    >>> lda.fit(X)
    LatentDirichletAllocation(...)
    >>> # get topics for some given samples:
    >>> lda.transform(X[-2:])
    array([[0.00360392, 0.25499205, 0.0036211 , 0.64236448, 0.09541846],
           [0.15297572, 0.00362644, 0.44412786, 0.39568399, 0.003586  ]])
    """

    _parameter_constraints: dict = {
        "n_components": [Interval(Integral, 0, None, closed="neither")],
        "doc_topic_prior": [None, Interval(Real, 0, 1, closed="both")],
        "topic_word_prior": [None, Interval(Real, 0, 1, closed="both")],
        "learning_method": [StrOptions({"batch", "online"})],
        "learning_decay": [Interval(Real, 0, 1, closed="both")],
        "learning_offset": [Interval(Real, 1.0, None, closed="left")],
        "max_iter": [Interval(Integral, 0, None, closed="left")],
        "batch_size": [Interval(Integral, 0, None, closed="neither")],
        "evaluate_every": [Interval(Integral, None, None, closed="neither")],
        "total_samples": [Interval(Real, 0, None, closed="neither")],
        "perp_tol": [Interval(Real, 0, None, closed="left")],
        "mean_change_tol": [Interval(Real, 0, None, closed="left")],
        "max_doc_update_iter": [Interval(Integral, 0, None, closed="left")],
        "n_jobs": [None, Integral],
        "verbose": ["verbose"],
        "random_state": ["random_state"],
    }

    def __init__(
        self,
        n_components=10,
        *,
        doc_topic_prior=None,
        topic_word_prior=None,
        learning_method="batch",
        learning_decay=0.7,
        learning_offset=10.0,
        max_iter=10,
        batch_size=128,
        evaluate_every=-1,
        total_samples=1e6,
        perp_tol=1e-1,
        mean_change_tol=1e-3,
        max_doc_update_iter=100,
        n_jobs=None,
        verbose=0,
        random_state=None,
    ):
        self.n_components = n_components
        self.doc_topic_prior = doc_topic_prior
        self.topic_word_prior = topic_word_prior
        self.learning_method = learning_method
        self.learning_decay = learning_decay
        self.learning_offset = learning_offset
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.evaluate_every = evaluate_every
        self.total_samples = total_samples
        self.perp_tol = perp_tol
        self.mean_change_tol = mean_change_tol
        self.max_doc_update_iter = max_doc_update_iter
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.random_state = random_state

    def _init_latent_vars(self, n_features, dtype=np.float64):
        """Initialize latent variables."""

        self.random_state_ = check_random_state(self.random_state)
        self.n_batch_iter_ = 1
        self.n_iter_ = 0

        if self.doc_topic_prior is None:
            self.doc_topic_prior_ = 1.0 / self.n_components
        else:
            self.doc_topic_prior_ = self.doc_topic_prior

        if self.topic_word_prior is None:
            self.topic_word_prior_ = 1.0 / self.n_components
        else:
            self.topic_word_prior_ = self.topic_word_prior

        init_gamma = 100.0
        init_var = 1.0 / init_gamma
        # In the literature, this is called `lambda`
        self.components_ = self.random_state_.gamma(
            init_gamma, init_var, (self.n_components, n_features)
        ).astype(dtype, copy=False)

        # In the literature, this is `exp(E[log(beta)])`
        self.exp_dirichlet_component_ = np.exp(
            _dirichlet_expectation_2d(self.components_)
        )

    def _e_step(self, X, cal_sstats, random_init, parallel=None):
        """E-step in EM update.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        cal_sstats : bool
            Parameter that indicates whether to calculate sufficient
            statistics or not. Set ``cal_sstats`` to True when we need to
            run the M-step.

        random_init : bool
            Parameter that indicates whether to initialize the document topic
            distribution randomly in the E-step. Set it to True in training
            steps.

        parallel : joblib.Parallel, default=None
            Pre-initialized instance of joblib.Parallel.

        Returns
        -------
        (doc_topic_distr, suff_stats) :
            `doc_topic_distr` is the unnormalized topic distribution for each
            document. In the literature, this is called `gamma`.
            `suff_stats` is the expected sufficient statistics for the M-step.
            When `cal_sstats == False`, it will be None.
        """
        # Run e-step in parallel
        random_state = self.random_state_ if random_init else None

        # TODO: make Parallel._effective_n_jobs public instead?
        n_jobs = effective_n_jobs(self.n_jobs)
        if parallel is None:
            parallel = Parallel(n_jobs=n_jobs, verbose=max(0, self.verbose - 1))
        results = parallel(
            delayed(_update_doc_distribution)(
                X[idx_slice, :],
                self.exp_dirichlet_component_,
                self.doc_topic_prior_,
                self.max_doc_update_iter,
                self.mean_change_tol,
                cal_sstats,
                random_state,
            )
            for idx_slice in gen_even_slices(X.shape[0], n_jobs)
        )

        # merge result
        doc_topics, sstats_list = zip(*results)
        doc_topic_distr = np.vstack(doc_topics)

        if cal_sstats:
            # This step finishes computing the sufficient statistics for the
            # M-step.
            suff_stats = np.zeros(self.components_.shape, dtype=self.components_.dtype)
            for sstats in sstats_list:
                suff_stats += sstats
            suff_stats *= self.exp_dirichlet_component_
        else:
            suff_stats = None

        return (doc_topic_distr, suff_stats)
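
    # For reference: after the merge above, the expected sufficient statistics
    # satisfy, in the notation of Hoffman et al. (2010),
    #
    #     suff_stats[k, w] = sum_d n_{dw} * phi_{dwk}
    #
    # since each per-document term exp(E[log theta_{dk}]) * n_{dw} / norm_phi_{dw}
    # accumulated in `_update_doc_distribution` is multiplied by
    # exp(E[log beta_{kw}]) in `_e_step`.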

    def _em_step(self, X, total_samples, batch_update, parallel=None):
        """EM update for 1 iteration.

        Update `components_` by batch VB or online VB.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        total_samples : int
            Total number of documents. It is only used when
            batch_update is `False`.

        batch_update : bool
            Parameter that controls the updating method.
            `True` for batch learning, `False` for online learning.

        parallel : joblib.Parallel, default=None
            Pre-initialized instance of joblib.Parallel

        Returns
        -------
        None
        """
        # E-step
        _, suff_stats = self._e_step(
            X, cal_sstats=True, random_init=True, parallel=parallel
        )

        # M-step
        if batch_update:
            self.components_ = self.topic_word_prior_ + suff_stats
        else:
            # online update
            # In the literature, the weight is `rho`
            weight = np.power(
                self.learning_offset + self.n_batch_iter_, -self.learning_decay
            )
            doc_ratio = float(total_samples) / X.shape[0]
            self.components_ *= 1 - weight
            self.components_ += weight * (
                self.topic_word_prior_ + doc_ratio * suff_stats
            )

        # update `components_` related variables
        self.exp_dirichlet_component_ = np.exp(
            _dirichlet_expectation_2d(self.components_)
        )
        self.n_batch_iter_ += 1
        return
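
    # For reference, a sketch of the online M-step performed by `_em_step`
    # above, following Hoffman et al. (2010); names refer to the attributes
    # and local variables used in that method:
    #
    #     rho = (learning_offset + n_batch_iter_) ** (-learning_decay)
    #     components_ = (1 - rho) * components_
    #                   + rho * (topic_word_prior_ + doc_ratio * suff_stats)
    #
    # where `doc_ratio = total_samples / n_samples_in_batch` and `suff_stats`
    # are the expected sufficient statistics from the E-step on the mini-batch.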

    def _more_tags(self):
        return {
            "preserves_dtype": [np.float64, np.float32],
            "requires_positive_X": True,
        }

    def _check_non_neg_array(self, X, reset_n_features, whom):
        """Check X format.

        Check X format and make sure there are no negative values in X.

        Parameters
        ----------
        X : array-like or sparse matrix
        """
        dtype = [np.float64, np.float32] if reset_n_features else self.components_.dtype

        X = self._validate_data(
            X,
            reset=reset_n_features,
            accept_sparse="csr",
            dtype=dtype,
        )
        check_non_negative(X, whom)

        return X

    @_fit_context(prefer_skip_nested_validation=True)
    def partial_fit(self, X, y=None):
        """Online VB with Mini-Batch update.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self
            Partially fitted estimator.
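
        Examples
        --------
        A minimal sketch of incremental fitting on mini-batches of a
        bag-of-words count matrix (reusing the toy data from the class
        example above):

        >>> from sklearn.datasets import make_multilabel_classification
        >>> from sklearn.decomposition import LatentDirichletAllocation
        >>> X, _ = make_multilabel_classification(random_state=0)
        >>> lda = LatentDirichletAllocation(n_components=5, random_state=0)
        >>> for batch in (X[:50], X[50:]):
        ...     lda = lda.partial_fit(batch)
        >>> lda.transform(X[:2]).shape
        (2, 5)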
  482. """
  483. first_time = not hasattr(self, "components_")
  484. X = self._check_non_neg_array(
  485. X, reset_n_features=first_time, whom="LatentDirichletAllocation.partial_fit"
  486. )
  487. n_samples, n_features = X.shape
  488. batch_size = self.batch_size
  489. # initialize parameters or check
  490. if first_time:
  491. self._init_latent_vars(n_features, dtype=X.dtype)
  492. if n_features != self.components_.shape[1]:
  493. raise ValueError(
  494. "The provided data has %d dimensions while "
  495. "the model was trained with feature size %d."
  496. % (n_features, self.components_.shape[1])
  497. )
  498. n_jobs = effective_n_jobs(self.n_jobs)
  499. with Parallel(n_jobs=n_jobs, verbose=max(0, self.verbose - 1)) as parallel:
  500. for idx_slice in gen_batches(n_samples, batch_size):
  501. self._em_step(
  502. X[idx_slice, :],
  503. total_samples=self.total_samples,
  504. batch_update=False,
  505. parallel=parallel,
  506. )
  507. return self

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        """Learn model for the data X with variational Bayes method.

        When `learning_method` is 'online', use mini-batch update.
        Otherwise, use batch update.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        self
            Fitted estimator.
        """
        X = self._check_non_neg_array(
            X, reset_n_features=True, whom="LatentDirichletAllocation.fit"
        )
        n_samples, n_features = X.shape
        max_iter = self.max_iter
        evaluate_every = self.evaluate_every
        learning_method = self.learning_method
        batch_size = self.batch_size

        # initialize parameters
        self._init_latent_vars(n_features, dtype=X.dtype)
        # change to perplexity later
        last_bound = None
        n_jobs = effective_n_jobs(self.n_jobs)
        with Parallel(n_jobs=n_jobs, verbose=max(0, self.verbose - 1)) as parallel:
            for i in range(max_iter):
                if learning_method == "online":
                    for idx_slice in gen_batches(n_samples, batch_size):
                        self._em_step(
                            X[idx_slice, :],
                            total_samples=n_samples,
                            batch_update=False,
                            parallel=parallel,
                        )
                else:
                    # batch update
                    self._em_step(
                        X, total_samples=n_samples, batch_update=True, parallel=parallel
                    )

                # check perplexity
                if evaluate_every > 0 and (i + 1) % evaluate_every == 0:
                    doc_topics_distr, _ = self._e_step(
                        X, cal_sstats=False, random_init=False, parallel=parallel
                    )
                    bound = self._perplexity_precomp_distr(
                        X, doc_topics_distr, sub_sampling=False
                    )
                    if self.verbose:
                        print(
                            "iteration: %d of max_iter: %d, perplexity: %.4f"
                            % (i + 1, max_iter, bound)
                        )

                    if last_bound and abs(last_bound - bound) < self.perp_tol:
                        break
                    last_bound = bound

                elif self.verbose:
                    print("iteration: %d of max_iter: %d" % (i + 1, max_iter))
                self.n_iter_ += 1

            # calculate final perplexity value on train set
            doc_topics_distr, _ = self._e_step(
                X, cal_sstats=False, random_init=False, parallel=parallel
            )
            self.bound_ = self._perplexity_precomp_distr(
                X, doc_topics_distr, sub_sampling=False
            )

        return self

    def _unnormalized_transform(self, X):
        """Transform data X according to fitted model.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        Returns
        -------
        doc_topic_distr : ndarray of shape (n_samples, n_components)
            Document topic distribution for X.
        """
        doc_topic_distr, _ = self._e_step(X, cal_sstats=False, random_init=False)

        return doc_topic_distr

    def transform(self, X):
        """Transform data X according to the fitted model.

        .. versionchanged:: 0.18
            *doc_topic_distr* is now normalized.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        Returns
        -------
        doc_topic_distr : ndarray of shape (n_samples, n_components)
            Document topic distribution for X.
        """
        check_is_fitted(self)
        X = self._check_non_neg_array(
            X, reset_n_features=False, whom="LatentDirichletAllocation.transform"
        )
        doc_topic_distr = self._unnormalized_transform(X)
        doc_topic_distr /= doc_topic_distr.sum(axis=1)[:, np.newaxis]
        return doc_topic_distr

    def _approx_bound(self, X, doc_topic_distr, sub_sampling):
        """Estimate the variational bound.

        Estimate the variational bound over "all documents" using only the
        documents passed in as X. Since the log-likelihood of each word cannot
        be computed directly, we use this bound to estimate it.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        doc_topic_distr : ndarray of shape (n_samples, n_components)
            Document topic distribution. In the literature, this is called
            gamma.

        sub_sampling : bool, default=False
            Compensate for subsampling of documents.
            It is used when calculating the bound in online learning.

        Returns
        -------
        score : float
        """

        def _loglikelihood(prior, distr, dirichlet_distr, size):
            # calculate log-likelihood
            score = np.sum((prior - distr) * dirichlet_distr)
            score += np.sum(gammaln(distr) - gammaln(prior))
            score += np.sum(gammaln(prior * size) - gammaln(np.sum(distr, 1)))
            return score

        is_sparse_x = sp.issparse(X)
        n_samples, n_components = doc_topic_distr.shape
        n_features = self.components_.shape[1]
        score = 0

        dirichlet_doc_topic = _dirichlet_expectation_2d(doc_topic_distr)
        dirichlet_component_ = _dirichlet_expectation_2d(self.components_)
        doc_topic_prior = self.doc_topic_prior_
        topic_word_prior = self.topic_word_prior_

        if is_sparse_x:
            X_data = X.data
            X_indices = X.indices
            X_indptr = X.indptr

        # E[log p(docs | theta, beta)]
        for idx_d in range(0, n_samples):
            if is_sparse_x:
                ids = X_indices[X_indptr[idx_d] : X_indptr[idx_d + 1]]
                cnts = X_data[X_indptr[idx_d] : X_indptr[idx_d + 1]]
            else:
                ids = np.nonzero(X[idx_d, :])[0]
                cnts = X[idx_d, ids]
            temp = (
                dirichlet_doc_topic[idx_d, :, np.newaxis] + dirichlet_component_[:, ids]
            )
            norm_phi = logsumexp(temp, axis=0)
            score += np.dot(cnts, norm_phi)

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        score += _loglikelihood(
            doc_topic_prior, doc_topic_distr, dirichlet_doc_topic, self.n_components
        )

        # Compensate for the subsampling of the population of documents
        if sub_sampling:
            doc_ratio = float(self.total_samples) / n_samples
            score *= doc_ratio

        # E[log p(beta | eta) - log q (beta | lambda)]
        score += _loglikelihood(
            topic_word_prior, self.components_, dirichlet_component_, n_features
        )

        return score

    def score(self, X, y=None):
        """Calculate approximate log-likelihood as score.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        y : Ignored
            Not used, present here for API consistency by convention.

        Returns
        -------
        score : float
            Use approximate bound as score.
        """
        check_is_fitted(self)
        X = self._check_non_neg_array(
            X, reset_n_features=False, whom="LatentDirichletAllocation.score"
        )
        doc_topic_distr = self._unnormalized_transform(X)
        score = self._approx_bound(X, doc_topic_distr, sub_sampling=False)
        return score

    def _perplexity_precomp_distr(self, X, doc_topic_distr=None, sub_sampling=False):
        """Calculate approximate perplexity for data X with ability to accept
        precomputed doc_topic_distr.

        Perplexity is defined as exp(-1. * log-likelihood per word)

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        doc_topic_distr : ndarray of shape (n_samples, n_components), \
                default=None
            Document topic distribution.
            If it is None, it will be generated by applying transform on X.

        Returns
        -------
        score : float
            Perplexity score.
        """
        if doc_topic_distr is None:
            doc_topic_distr = self._unnormalized_transform(X)
        else:
            n_samples, n_components = doc_topic_distr.shape
            if n_samples != X.shape[0]:
                raise ValueError(
                    "Number of samples in X and doc_topic_distr do not match."
                )

            if n_components != self.n_components:
                raise ValueError("Number of topics does not match.")

        current_samples = X.shape[0]
        bound = self._approx_bound(X, doc_topic_distr, sub_sampling)

        if sub_sampling:
            word_cnt = X.sum() * (float(self.total_samples) / current_samples)
        else:
            word_cnt = X.sum()
        perword_bound = bound / word_cnt

        return np.exp(-1.0 * perword_bound)

    def perplexity(self, X, sub_sampling=False):
        """Calculate approximate perplexity for data X.

        Perplexity is defined as exp(-1. * log-likelihood per word)

        .. versionchanged:: 0.19
            *doc_topic_distr* argument has been deprecated and is ignored
            because the user no longer has access to the unnormalized
            distribution.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Document word matrix.

        sub_sampling : bool
            Whether to do sub-sampling or not.

        Returns
        -------
        score : float
            Perplexity score.
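
        Examples
        --------
        A minimal sketch on the toy count matrix from the class example,
        illustrating the relation to :meth:`score` (the approximate
        log-likelihood bound):

        >>> import numpy as np
        >>> from sklearn.datasets import make_multilabel_classification
        >>> from sklearn.decomposition import LatentDirichletAllocation
        >>> X, _ = make_multilabel_classification(random_state=0)
        >>> lda = LatentDirichletAllocation(n_components=5, random_state=0).fit(X)
        >>> perp = lda.perplexity(X)
        >>> bool(np.isclose(perp, np.exp(-lda.score(X) / X.sum())))
        True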
  746. """
  747. check_is_fitted(self)
  748. X = self._check_non_neg_array(
  749. X, reset_n_features=True, whom="LatentDirichletAllocation.perplexity"
  750. )
  751. return self._perplexity_precomp_distr(X, sub_sampling=sub_sampling)

    @property
    def _n_features_out(self):
        """Number of transformed output features."""
        return self.components_.shape[0]