  1. """ Dictionary learning.
  2. """
  3. # Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort
  4. # License: BSD 3 clause
  5. import itertools
  6. import sys
  7. import time
  8. import warnings
  9. from math import ceil
  10. from numbers import Integral, Real
  11. import numpy as np
  12. from joblib import effective_n_jobs
  13. from scipy import linalg
  14. from ..base import (
  15. BaseEstimator,
  16. ClassNamePrefixFeaturesOutMixin,
  17. TransformerMixin,
  18. _fit_context,
  19. )
  20. from ..linear_model import Lars, Lasso, LassoLars, orthogonal_mp_gram
  21. from ..utils import check_array, check_random_state, gen_batches, gen_even_slices
  22. from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
  23. from ..utils.extmath import randomized_svd, row_norms, svd_flip
  24. from ..utils.parallel import Parallel, delayed
  25. from ..utils.validation import check_is_fitted
  26. def _check_positive_coding(method, positive):
  27. if positive and method in ["omp", "lars"]:
  28. raise ValueError(
  29. "Positive constraint not supported for '{}' coding method.".format(method)
  30. )
  31. def _sparse_encode_precomputed(
  32. X,
  33. dictionary,
  34. *,
  35. gram=None,
  36. cov=None,
  37. algorithm="lasso_lars",
  38. regularization=None,
  39. copy_cov=True,
  40. init=None,
  41. max_iter=1000,
  42. verbose=0,
  43. positive=False,
  44. ):
  45. """Generic sparse coding with precomputed Gram and/or covariance matrices.
  46. Each row of the result is the solution to a Lasso problem.
  47. Parameters
  48. ----------
  49. X : ndarray of shape (n_samples, n_features)
  50. Data matrix.
  51. dictionary : ndarray of shape (n_components, n_features)
  52. The dictionary matrix against which to solve the sparse coding of
  53. the data. Some of the algorithms assume normalized rows.
  54. gram : ndarray of shape (n_components, n_components), default=None
  55. Precomputed Gram matrix, `dictionary * dictionary'`
  56. gram can be `None` if method is 'threshold'.
  57. cov : ndarray of shape (n_components, n_samples), default=None
  58. Precomputed covariance, `dictionary * X'`.
  59. algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}, \
  60. default='lasso_lars'
  61. The algorithm used:
  62. * `'lars'`: uses the least angle regression method
  63. (`linear_model.lars_path`);
  64. * `'lasso_lars'`: uses Lars to compute the Lasso solution;
  65. * `'lasso_cd'`: uses the coordinate descent method to compute the
  66. Lasso solution (`linear_model.Lasso`). lasso_lars will be faster if
  67. the estimated components are sparse;
  68. * `'omp'`: uses orthogonal matching pursuit to estimate the sparse
  69. solution;
  70. * `'threshold'`: squashes to zero all coefficients less than
  71. regularization from the projection `dictionary * data'`.
  72. regularization : int or float, default=None
  73. The regularization parameter. It corresponds to alpha when
  74. algorithm is `'lasso_lars'`, `'lasso_cd'` or `'threshold'`.
  75. Otherwise it corresponds to `n_nonzero_coefs`.
  76. init : ndarray of shape (n_samples, n_components), default=None
  77. Initialization value of the sparse code. Only used if
  78. `algorithm='lasso_cd'`.
  79. max_iter : int, default=1000
  80. Maximum number of iterations to perform if `algorithm='lasso_cd'` or
  81. `'lasso_lars'`.
  82. copy_cov : bool, default=True
  83. Whether to copy the precomputed covariance matrix; if `False`, it may
  84. be overwritten.
  85. verbose : int, default=0
  86. Controls the verbosity; the higher, the more messages.
  87. positive: bool, default=False
  88. Whether to enforce a positivity constraint on the sparse code.
  89. .. versionadded:: 0.20
  90. Returns
  91. -------
  92. code : ndarray of shape (n_components, n_features)
  93. The sparse codes.
  94. """
  95. n_samples, n_features = X.shape
  96. n_components = dictionary.shape[0]
  97. if algorithm == "lasso_lars":
  98. alpha = float(regularization) / n_features # account for scaling
  99. try:
  100. err_mgt = np.seterr(all="ignore")
  101. # Not passing in verbose=max(0, verbose-1) because Lars.fit already
  102. # corrects the verbosity level.
  103. lasso_lars = LassoLars(
  104. alpha=alpha,
  105. fit_intercept=False,
  106. verbose=verbose,
  107. precompute=gram,
  108. fit_path=False,
  109. positive=positive,
  110. max_iter=max_iter,
  111. )
  112. lasso_lars.fit(dictionary.T, X.T, Xy=cov)
  113. new_code = lasso_lars.coef_
  114. finally:
  115. np.seterr(**err_mgt)
  116. elif algorithm == "lasso_cd":
  117. alpha = float(regularization) / n_features # account for scaling
  118. # TODO: Make verbosity argument for Lasso?
  119. # sklearn.linear_model.coordinate_descent.enet_path has a verbosity
  120. # argument that we could pass in from Lasso.
  121. clf = Lasso(
  122. alpha=alpha,
  123. fit_intercept=False,
  124. precompute=gram,
  125. max_iter=max_iter,
  126. warm_start=True,
  127. positive=positive,
  128. )
  129. if init is not None:
  130. # In some workflows using coordinate descent algorithms:
  131. # - users might provide NumPy arrays with read-only buffers
  132. # - `joblib` might memmap arrays making their buffer read-only
  133. # TODO: move this handling (which is currently too broad)
  134. # closer to the actual private function which need buffers to be writable.
  135. if not init.flags["WRITEABLE"]:
  136. init = np.array(init)
  137. clf.coef_ = init
  138. clf.fit(dictionary.T, X.T, check_input=False)
  139. new_code = clf.coef_
  140. elif algorithm == "lars":
  141. try:
  142. err_mgt = np.seterr(all="ignore")
  143. # Not passing in verbose=max(0, verbose-1) because Lars.fit already
  144. # corrects the verbosity level.
  145. lars = Lars(
  146. fit_intercept=False,
  147. verbose=verbose,
  148. precompute=gram,
  149. n_nonzero_coefs=int(regularization),
  150. fit_path=False,
  151. )
  152. lars.fit(dictionary.T, X.T, Xy=cov)
  153. new_code = lars.coef_
  154. finally:
  155. np.seterr(**err_mgt)
  156. elif algorithm == "threshold":
  157. new_code = (np.sign(cov) * np.maximum(np.abs(cov) - regularization, 0)).T
  158. if positive:
  159. np.clip(new_code, 0, None, out=new_code)
  160. elif algorithm == "omp":
  161. new_code = orthogonal_mp_gram(
  162. Gram=gram,
  163. Xy=cov,
  164. n_nonzero_coefs=int(regularization),
  165. tol=None,
  166. norms_squared=row_norms(X, squared=True),
  167. copy_Xy=copy_cov,
  168. ).T
  169. return new_code.reshape(n_samples, n_components)
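

# Worked mini-example (illustrative comment, not library code): for
# `algorithm="threshold"` the branch above is plain soft-thresholding of the
# projection `cov = dictionary @ X.T`. With cov = [2.5, -0.3, 1.0] and
# regularization = 1.0:
#
#     np.sign(cov) * np.maximum(np.abs(cov) - 1.0, 0)  ->  [1.5, 0.0, 0.0]
#
# coefficients whose magnitude is below the threshold are squashed to zero,
# and the remaining ones shrink toward zero by the threshold amount.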


@validate_params(
    {
        "X": ["array-like"],
        "dictionary": ["array-like"],
        "gram": ["array-like", None],
        "cov": ["array-like", None],
        "algorithm": [
            StrOptions({"lasso_lars", "lasso_cd", "lars", "omp", "threshold"})
        ],
        "n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
        "alpha": [Interval(Real, 0, None, closed="left"), None],
        "copy_cov": ["boolean"],
        "init": ["array-like", None],
        "max_iter": [Interval(Integral, 0, None, closed="left")],
        "n_jobs": [Integral, None],
        "check_input": ["boolean"],
        "verbose": ["verbose"],
        "positive": ["boolean"],
    },
    prefer_skip_nested_validation=True,
)
# XXX : could be moved to the linear_model module
def sparse_encode(
    X,
    dictionary,
    *,
    gram=None,
    cov=None,
    algorithm="lasso_lars",
    n_nonzero_coefs=None,
    alpha=None,
    copy_cov=True,
    init=None,
    max_iter=1000,
    n_jobs=None,
    check_input=True,
    verbose=0,
    positive=False,
):
    """Sparse coding.

    Each row of the result is the solution to a sparse coding problem.
    The goal is to find a sparse array `code` such that::

        X ~= code * dictionary

    Read more in the :ref:`User Guide <SparseCoder>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data matrix.

    dictionary : array-like of shape (n_components, n_features)
        The dictionary matrix against which to solve the sparse coding of
        the data. Some of the algorithms assume normalized rows for meaningful
        output.

    gram : array-like of shape (n_components, n_components), default=None
        Precomputed Gram matrix, `dictionary * dictionary'`.

    cov : array-like of shape (n_components, n_samples), default=None
        Precomputed covariance, `dictionary * X'`.

    algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}, \
            default='lasso_lars'
        The algorithm used:

        * `'lars'`: uses the least angle regression method
          (`linear_model.lars_path`);
        * `'lasso_lars'`: uses Lars to compute the Lasso solution;
        * `'lasso_cd'`: uses the coordinate descent method to compute the
          Lasso solution (`linear_model.Lasso`). lasso_lars will be faster if
          the estimated components are sparse;
        * `'omp'`: uses orthogonal matching pursuit to estimate the sparse
          solution;
        * `'threshold'`: squashes to zero all coefficients less than
          regularization from the projection `dictionary * data'`.

    n_nonzero_coefs : int, default=None
        Number of nonzero coefficients to target in each column of the
        solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
        and is overridden by `alpha` in the `omp` case. If `None`, then
        `n_nonzero_coefs=int(n_features / 10)`.

    alpha : float, default=None
        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
        penalty applied to the L1 norm.
        If `algorithm='threshold'`, `alpha` is the absolute value of the
        threshold below which coefficients will be squashed to zero.
        If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
        the reconstruction error targeted. In this case, it overrides
        `n_nonzero_coefs`.
        If `None`, default to 1.

    copy_cov : bool, default=True
        Whether to copy the precomputed covariance matrix; if `False`, it may
        be overwritten.

    init : ndarray of shape (n_samples, n_components), default=None
        Initialization value of the sparse codes. Only used if
        `algorithm='lasso_cd'`.

    max_iter : int, default=1000
        Maximum number of iterations to perform if `algorithm='lasso_cd'` or
        `'lasso_lars'`.

    n_jobs : int, default=None
        Number of parallel jobs to run.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    check_input : bool, default=True
        If `False`, the input arrays X and dictionary will not be checked.

    verbose : int, default=0
        Controls the verbosity; the higher, the more messages.

    positive : bool, default=False
        Whether to enforce positivity when finding the encoding.

        .. versionadded:: 0.20

    Returns
    -------
    code : ndarray of shape (n_samples, n_components)
        The sparse codes.

    See Also
    --------
    sklearn.linear_model.lars_path : Compute Least Angle Regression or Lasso
        path using LARS algorithm.
    sklearn.linear_model.orthogonal_mp : Solves Orthogonal Matching Pursuit problems.
    sklearn.linear_model.Lasso : Train Linear Model with L1 prior as regularizer.
    SparseCoder : Find a sparse representation of data from a fixed precomputed
        dictionary.
    """
    if check_input:
        if algorithm == "lasso_cd":
            dictionary = check_array(
                dictionary, order="C", dtype=[np.float64, np.float32]
            )
            X = check_array(X, order="C", dtype=[np.float64, np.float32])
        else:
            dictionary = check_array(dictionary)
            X = check_array(X)

    if dictionary.shape[1] != X.shape[1]:
        raise ValueError(
            "Dictionary and X have different numbers of features: "
            "dictionary.shape: {} X.shape: {}".format(dictionary.shape, X.shape)
        )

    _check_positive_coding(algorithm, positive)

    return _sparse_encode(
        X,
        dictionary,
        gram=gram,
        cov=cov,
        algorithm=algorithm,
        n_nonzero_coefs=n_nonzero_coefs,
        alpha=alpha,
        copy_cov=copy_cov,
        init=init,
        max_iter=max_iter,
        n_jobs=n_jobs,
        verbose=verbose,
        positive=positive,
    )
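

# A minimal usage sketch of `sparse_encode` (illustrative, not part of the
# module): encode a few samples against a row-normalized random dictionary
# with OMP, targeting two nonzero coefficients per sample.
#
#     >>> import numpy as np
#     >>> from sklearn.decomposition import sparse_encode
#     >>> rng = np.random.RandomState(0)
#     >>> dictionary = rng.randn(8, 10)
#     >>> dictionary /= np.linalg.norm(dictionary, axis=1, keepdims=True)
#     >>> X = rng.randn(3, 10)
#     >>> code = sparse_encode(X, dictionary, algorithm="omp", n_nonzero_coefs=2)
#     >>> code.shape
#     (3, 8)
#     >>> bool((code != 0).sum(axis=1).max() <= 2)
#     True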


def _sparse_encode(
    X,
    dictionary,
    *,
    gram=None,
    cov=None,
    algorithm="lasso_lars",
    n_nonzero_coefs=None,
    alpha=None,
    copy_cov=True,
    init=None,
    max_iter=1000,
    n_jobs=None,
    verbose=0,
    positive=False,
):
    """Sparse coding without input/parameter validation."""
    n_samples, n_features = X.shape
    n_components = dictionary.shape[0]

    if algorithm in ("lars", "omp"):
        regularization = n_nonzero_coefs
        if regularization is None:
            regularization = min(max(n_features / 10, 1), n_components)
    else:
        regularization = alpha
        if regularization is None:
            regularization = 1.0

    if gram is None and algorithm != "threshold":
        gram = np.dot(dictionary, dictionary.T)

    if cov is None and algorithm != "lasso_cd":
        copy_cov = False
        cov = np.dot(dictionary, X.T)

    if effective_n_jobs(n_jobs) == 1 or algorithm == "threshold":
        code = _sparse_encode_precomputed(
            X,
            dictionary,
            gram=gram,
            cov=cov,
            algorithm=algorithm,
            regularization=regularization,
            copy_cov=copy_cov,
            init=init,
            max_iter=max_iter,
            verbose=verbose,
            positive=positive,
        )
        return code

    # Enter parallel code block
    n_samples = X.shape[0]
    n_components = dictionary.shape[0]
    code = np.empty((n_samples, n_components))
    slices = list(gen_even_slices(n_samples, effective_n_jobs(n_jobs)))

    code_views = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(_sparse_encode_precomputed)(
            X[this_slice],
            dictionary,
            gram=gram,
            cov=cov[:, this_slice] if cov is not None else None,
            algorithm=algorithm,
            regularization=regularization,
            copy_cov=copy_cov,
            init=init[this_slice] if init is not None else None,
            max_iter=max_iter,
            verbose=verbose,
            positive=positive,
        )
        for this_slice in slices
    )
    for this_slice, this_view in zip(slices, code_views):
        code[this_slice] = this_view

    return code


def _update_dict(
    dictionary,
    Y,
    code,
    A=None,
    B=None,
    verbose=False,
    random_state=None,
    positive=False,
):
    """Update the dense dictionary factor in place.

    Parameters
    ----------
    dictionary : ndarray of shape (n_components, n_features)
        Value of the dictionary at the previous iteration.

    Y : ndarray of shape (n_samples, n_features)
        Data matrix.

    code : ndarray of shape (n_samples, n_components)
        Sparse coding of the data against which to optimize the dictionary.

    A : ndarray of shape (n_components, n_components), default=None
        Together with `B`, sufficient stats of the online model to update the
        dictionary.

    B : ndarray of shape (n_features, n_components), default=None
        Together with `A`, sufficient stats of the online model to update the
        dictionary.

    verbose : bool, default=False
        Degree of output the procedure will print.

    random_state : int, RandomState instance or None, default=None
        Used for randomly initializing the dictionary. Pass an int for
        reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    positive : bool, default=False
        Whether to enforce positivity when finding the dictionary.

        .. versionadded:: 0.20
    """
    n_samples, n_components = code.shape
    random_state = check_random_state(random_state)

    if A is None:
        A = code.T @ code
    if B is None:
        B = Y.T @ code

    n_unused = 0

    for k in range(n_components):
        if A[k, k] > 1e-6:
            # 1e-6 is arbitrary but consistent with the spams implementation
            dictionary[k] += (B[:, k] - A[k] @ dictionary) / A[k, k]
        else:
            # kth atom is almost never used -> sample a new one from the data
            newd = Y[random_state.choice(n_samples)]

            # add small noise to avoid making the sparse coding ill conditioned
            noise_level = 0.01 * (newd.std() or 1)  # avoid 0 std
            noise = random_state.normal(0, noise_level, size=len(newd))

            dictionary[k] = newd + noise
            code[:, k] = 0
            n_unused += 1

        if positive:
            np.clip(dictionary[k], 0, None, out=dictionary[k])

        # Projection on the constraint set ||V_k|| <= 1
        dictionary[k] /= max(linalg.norm(dictionary[k]), 1)

    if verbose and n_unused > 0:
        print(f"{n_unused} unused atoms resampled.")


def _dict_learning(
    X,
    n_components,
    *,
    alpha,
    max_iter,
    tol,
    method,
    n_jobs,
    dict_init,
    code_init,
    callback,
    verbose,
    random_state,
    return_n_iter,
    positive_dict,
    positive_code,
    method_max_iter,
):
    """Main dictionary learning algorithm"""
    t0 = time.time()
    # Init the code and the dictionary with SVD of Y
    if code_init is not None and dict_init is not None:
        code = np.array(code_init, order="F")
        # Don't copy V, it will happen below
        dictionary = dict_init
    else:
        code, S, dictionary = linalg.svd(X, full_matrices=False)
        # flip the initial code's sign to enforce deterministic output
        code, dictionary = svd_flip(code, dictionary)
        dictionary = S[:, np.newaxis] * dictionary
    r = len(dictionary)
    if n_components <= r:  # True even if n_components=None
        code = code[:, :n_components]
        dictionary = dictionary[:n_components, :]
    else:
        code = np.c_[code, np.zeros((len(code), n_components - r))]
        dictionary = np.r_[
            dictionary, np.zeros((n_components - r, dictionary.shape[1]))
        ]

    # Fortran-order dict better suited for the sparse coding which is the
    # bottleneck of this algorithm.
    dictionary = np.asfortranarray(dictionary)

    errors = []
    current_cost = np.nan

    if verbose == 1:
        print("[dict_learning]", end=" ")

    # If max_iter is 0, number of iterations returned should be zero
    ii = -1

    for ii in range(max_iter):
        dt = time.time() - t0
        if verbose == 1:
            sys.stdout.write(".")
            sys.stdout.flush()
        elif verbose:
            print(
                "Iteration % 3i (elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)"
                % (ii, dt, dt / 60, current_cost)
            )

        # Update code
        code = sparse_encode(
            X,
            dictionary,
            algorithm=method,
            alpha=alpha,
            init=code,
            n_jobs=n_jobs,
            positive=positive_code,
            max_iter=method_max_iter,
            verbose=verbose,
        )

        # Update dictionary in place
        _update_dict(
            dictionary,
            X,
            code,
            verbose=verbose,
            random_state=random_state,
            positive=positive_dict,
        )

        # Cost function
        current_cost = 0.5 * np.sum((X - code @ dictionary) ** 2) + alpha * np.sum(
            np.abs(code)
        )
        errors.append(current_cost)

        if ii > 0:
            dE = errors[-2] - errors[-1]
            # assert(dE >= -tol * errors[-1])
            if dE < tol * errors[-1]:
                if verbose == 1:
                    # A line return
                    print("")
                elif verbose:
                    print("--- Convergence reached after %d iterations" % ii)
                break

        if ii % 5 == 0 and callback is not None:
            callback(locals())

    if return_n_iter:
        return code, dictionary, errors, ii + 1
    else:
        return code, dictionary, errors


def _check_warn_deprecated(param, name, default, additional_message=None):
    if param != "deprecated":
        msg = (
            f"'{name}' is deprecated in version 1.1 and will be removed in "
            "version 1.4."
        )
        if additional_message:
            msg += f" {additional_message}"
        warnings.warn(msg, FutureWarning)
        return param
    else:
        return default


def dict_learning_online(
    X,
    n_components=2,
    *,
    alpha=1,
    n_iter="deprecated",
    max_iter=None,
    return_code=True,
    dict_init=None,
    callback=None,
    batch_size=256,
    verbose=False,
    shuffle=True,
    n_jobs=None,
    method="lars",
    iter_offset="deprecated",
    random_state=None,
    return_inner_stats="deprecated",
    inner_stats="deprecated",
    return_n_iter="deprecated",
    positive_dict=False,
    positive_code=False,
    method_max_iter=1000,
    tol=1e-3,
    max_no_improvement=10,
):
    """Solve a dictionary learning matrix factorization problem online.

    Finds the best dictionary and the corresponding sparse code for
    approximating the data matrix X by solving::

        (U^*, V^*) = argmin 0.5 || X - U V ||_Fro^2 + alpha * || U ||_1,1
                     (U,V)
                     with || V_k ||_2 = 1 for all 0 <= k < n_components

    where V is the dictionary and U is the sparse code. ||.||_Fro stands for
    the Frobenius norm and ||.||_1,1 stands for the entry-wise matrix norm
    which is the sum of the absolute values of all the entries in the matrix.
    This is accomplished by repeatedly iterating over mini-batches by slicing
    the input data.

    Read more in the :ref:`User Guide <DictionaryLearning>`.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Data matrix.

    n_components : int or None, default=2
        Number of dictionary atoms to extract. If None, then ``n_components``
        is set to ``n_features``.

    alpha : float, default=1
        Sparsity controlling parameter.

    n_iter : int, default=100
        Number of mini-batch iterations to perform.

        .. deprecated:: 1.1
           `n_iter` is deprecated in 1.1 and will be removed in 1.4. Use
           `max_iter` instead.

    max_iter : int, default=None
        Maximum number of iterations over the complete dataset before
        stopping independently of any early stopping criterion heuristics.
        If ``max_iter`` is not None, ``n_iter`` is ignored.

        .. versionadded:: 1.1

    return_code : bool, default=True
        Whether to also return the code U or just the dictionary `V`.

    dict_init : ndarray of shape (n_components, n_features), default=None
        Initial values for the dictionary for warm restart scenarios.
        If `None`, the initial values for the dictionary are created
        with an SVD decomposition of the data via
        :func:`~sklearn.utils.extmath.randomized_svd`.

    callback : callable, default=None
        A callable that gets invoked at the end of each iteration.

    batch_size : int, default=256
        The number of samples to take in each batch.

        .. versionchanged:: 1.3
           The default value of `batch_size` changed from 3 to 256 in version 1.3.

    verbose : bool, default=False
        To control the verbosity of the procedure.

    shuffle : bool, default=True
        Whether to shuffle the data before splitting it in batches.

    n_jobs : int, default=None
        Number of parallel jobs to run.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    method : {'lars', 'cd'}, default='lars'
        * `'lars'`: uses the least angle regression method to solve the lasso
          problem (`linear_model.lars_path`);
        * `'cd'`: uses the coordinate descent method to compute the
          Lasso solution (`linear_model.Lasso`). Lars will be faster if
          the estimated components are sparse.

    iter_offset : int, default=0
        Number of previous iterations completed on the dictionary used for
        initialization.

        .. deprecated:: 1.1
           `iter_offset` serves internal purpose only and will be removed in 1.4.

    random_state : int, RandomState instance or None, default=None
        Used for initializing the dictionary when ``dict_init`` is not
        specified, randomly shuffling the data when ``shuffle`` is set to
        ``True``, and updating the dictionary. Pass an int for reproducible
        results across multiple function calls.
        See :term:`Glossary <random_state>`.

    return_inner_stats : bool, default=False
        Return the inner statistics A (dictionary covariance) and B
        (data approximation). Useful to restart the algorithm in an
        online setting. If `return_inner_stats` is `True`, `return_code` is
        ignored.

        .. deprecated:: 1.1
           `return_inner_stats` serves internal purpose only and will be
           removed in 1.4.

    inner_stats : tuple of (A, B) ndarrays, default=None
        Inner sufficient statistics that are kept by the algorithm.
        Passing them at initialization is useful in online settings, to
        avoid losing the history of the evolution.
        `A` `(n_components, n_components)` is the dictionary covariance matrix.
        `B` `(n_features, n_components)` is the data approximation matrix.

        .. deprecated:: 1.1
           `inner_stats` serves internal purpose only and will be removed in 1.4.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

        .. deprecated:: 1.1
           `return_n_iter` will be removed in 1.4 and n_iter will never be
           returned.

    positive_dict : bool, default=False
        Whether to enforce positivity when finding the dictionary.

        .. versionadded:: 0.20

    positive_code : bool, default=False
        Whether to enforce positivity when finding the code.

        .. versionadded:: 0.20

    method_max_iter : int, default=1000
        Maximum number of iterations to perform when solving the lasso problem.

        .. versionadded:: 0.22

    tol : float, default=1e-3
        Control early stopping based on the norm of the differences in the
        dictionary between 2 steps. Used only if `max_iter` is not None.

        To disable early stopping based on changes in the dictionary, set
        `tol` to 0.0.

        .. versionadded:: 1.1

    max_no_improvement : int, default=10
        Control early stopping based on the consecutive number of mini batches
        that do not yield an improvement on the smoothed cost function. Used
        only if `max_iter` is not None.

        To disable convergence detection based on cost function, set
        `max_no_improvement` to None.

        .. versionadded:: 1.1

    Returns
    -------
    code : ndarray of shape (n_samples, n_components)
        The sparse code (only returned if `return_code=True`).

    dictionary : ndarray of shape (n_components, n_features)
        The solutions to the dictionary learning problem.

    n_iter : int
        Number of iterations run. Returned only if `return_n_iter` is
        set to `True`.

    See Also
    --------
    dict_learning : Solve a dictionary learning matrix factorization problem.
    DictionaryLearning : Find a dictionary that sparsely encodes data.
    MiniBatchDictionaryLearning : A faster, less accurate, version of the
        dictionary learning algorithm.
    SparsePCA : Sparse Principal Components Analysis.
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    """
    deps = (return_n_iter, return_inner_stats, iter_offset, inner_stats)
    if max_iter is not None and not all(arg == "deprecated" for arg in deps):
        raise ValueError(
            "The following arguments are incompatible with 'max_iter': "
            "return_n_iter, return_inner_stats, iter_offset, inner_stats"
        )

    iter_offset = _check_warn_deprecated(iter_offset, "iter_offset", default=0)
    return_inner_stats = _check_warn_deprecated(
        return_inner_stats,
        "return_inner_stats",
        default=False,
        additional_message="From 1.4 inner_stats will never be returned.",
    )
    inner_stats = _check_warn_deprecated(inner_stats, "inner_stats", default=None)
    return_n_iter = _check_warn_deprecated(
        return_n_iter,
        "return_n_iter",
        default=False,
        additional_message=(
            "From 1.4 'n_iter' will never be returned. Refer to the 'n_iter_' and "
            "'n_steps_' attributes of the MiniBatchDictionaryLearning object instead."
        ),
    )

    if max_iter is not None:
        transform_algorithm = "lasso_" + method

        est = MiniBatchDictionaryLearning(
            n_components=n_components,
            alpha=alpha,
            max_iter=max_iter,
            n_jobs=n_jobs,
            fit_algorithm=method,
            batch_size=batch_size,
            shuffle=shuffle,
            dict_init=dict_init,
            random_state=random_state,
            transform_algorithm=transform_algorithm,
            transform_alpha=alpha,
            positive_code=positive_code,
            positive_dict=positive_dict,
            transform_max_iter=method_max_iter,
            verbose=verbose,
            callback=callback,
            tol=tol,
            max_no_improvement=max_no_improvement,
        ).fit(X)

        if not return_code:
            return est.components_
        else:
            code = est.transform(X)
            return code, est.components_

    # TODO(1.4) remove the whole old behavior
    # Fallback to old behavior

    n_iter = _check_warn_deprecated(
        n_iter, "n_iter", default=100, additional_message="Use 'max_iter' instead."
    )

    if n_components is None:
        n_components = X.shape[1]

    if method not in ("lars", "cd"):
        raise ValueError("Coding method not supported as a fit algorithm.")

    _check_positive_coding(method, positive_code)

    method = "lasso_" + method

    t0 = time.time()
    n_samples, n_features = X.shape
    # Avoid integer division problems
    alpha = float(alpha)
    random_state = check_random_state(random_state)

    # Init V with SVD of X
    if dict_init is not None:
        dictionary = dict_init
    else:
        _, S, dictionary = randomized_svd(X, n_components, random_state=random_state)
        dictionary = S[:, np.newaxis] * dictionary
    r = len(dictionary)
    if n_components <= r:
        dictionary = dictionary[:n_components, :]
    else:
        dictionary = np.r_[
            dictionary,
            np.zeros((n_components - r, dictionary.shape[1]), dtype=dictionary.dtype),
        ]

    if verbose == 1:
        print("[dict_learning]", end=" ")

    if shuffle:
        X_train = X.copy()
        random_state.shuffle(X_train)
    else:
        X_train = X

    X_train = check_array(
        X_train, order="C", dtype=[np.float64, np.float32], copy=False
    )

    # Fortran-order dict better suited for the sparse coding which is the
    # bottleneck of this algorithm.
    dictionary = check_array(dictionary, order="F", dtype=X_train.dtype, copy=False)
    dictionary = np.require(dictionary, requirements="W")

    batches = gen_batches(n_samples, batch_size)
    batches = itertools.cycle(batches)

    # The covariance of the dictionary
    if inner_stats is None:
        A = np.zeros((n_components, n_components), dtype=X_train.dtype)
        # The data approximation
        B = np.zeros((n_features, n_components), dtype=X_train.dtype)
    else:
        A = inner_stats[0].copy()
        B = inner_stats[1].copy()

    # If n_iter is zero, we need to return zero.
    ii = iter_offset - 1

    for ii, batch in zip(range(iter_offset, iter_offset + n_iter), batches):
        this_X = X_train[batch]
        dt = time.time() - t0
        if verbose == 1:
            sys.stdout.write(".")
            sys.stdout.flush()
        elif verbose:
            if verbose > 10 or ii % ceil(100.0 / verbose) == 0:
                print(
                    "Iteration % 3i (elapsed time: % 3is, % 4.1fmn)" % (ii, dt, dt / 60)
                )

        this_code = sparse_encode(
            this_X,
            dictionary,
            algorithm=method,
            alpha=alpha,
            n_jobs=n_jobs,
            check_input=False,
            positive=positive_code,
            max_iter=method_max_iter,
            verbose=verbose,
        )

        # Update the auxiliary variables
        if ii < batch_size - 1:
            theta = float((ii + 1) * batch_size)
        else:
            theta = float(batch_size**2 + ii + 1 - batch_size)
        beta = (theta + 1 - batch_size) / (theta + 1)

        A *= beta
        A += np.dot(this_code.T, this_code)
        B *= beta
        B += np.dot(this_X.T, this_code)

        # Update dictionary in place
        _update_dict(
            dictionary,
            this_X,
            this_code,
            A,
            B,
            verbose=verbose,
            random_state=random_state,
            positive=positive_dict,
        )

        # Maybe we need a stopping criteria based on the amount of
        # modification in the dictionary
        if callback is not None:
            callback(locals())

    if return_inner_stats:
        if return_n_iter:
            return dictionary, (A, B), ii - iter_offset + 1
        else:
            return dictionary, (A, B)
    if return_code:
        if verbose > 1:
            print("Learning code...", end=" ")
        elif verbose == 1:
            print("|", end=" ")
        code = sparse_encode(
            X,
            dictionary,
            algorithm=method,
            alpha=alpha,
            n_jobs=n_jobs,
            check_input=False,
            positive=positive_code,
            max_iter=method_max_iter,
            verbose=verbose,
        )
        if verbose > 1:
            dt = time.time() - t0
            print("done (total time: % 3is, % 4.1fmn)" % (dt, dt / 60))
        if return_n_iter:
            return code, dictionary, ii - iter_offset + 1
        else:
            return code, dictionary

    if return_n_iter:
        return dictionary, ii - iter_offset + 1
    else:
        return dictionary
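

# A minimal usage sketch of `dict_learning_online` (illustrative, not part of
# the module): factorize a small random matrix with the mini-batch solver and
# check the returned shapes.
#
#     >>> import numpy as np
#     >>> from sklearn.decomposition import dict_learning_online
#     >>> rng = np.random.RandomState(42)
#     >>> X = rng.randn(30, 8)
#     >>> code, dictionary = dict_learning_online(
#     ...     X, n_components=5, alpha=1.0, max_iter=20, batch_size=4,
#     ...     random_state=42)
#     >>> code.shape, dictionary.shape
#     ((30, 5), (5, 8))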


@validate_params(
    {
        "X": ["array-like"],
        "method": [StrOptions({"lars", "cd"})],
        "return_n_iter": ["boolean"],
        "method_max_iter": [Interval(Integral, 0, None, closed="left")],
    },
    prefer_skip_nested_validation=False,
)
def dict_learning(
    X,
    n_components,
    *,
    alpha,
    max_iter=100,
    tol=1e-8,
    method="lars",
    n_jobs=None,
    dict_init=None,
    code_init=None,
    callback=None,
    verbose=False,
    random_state=None,
    return_n_iter=False,
    positive_dict=False,
    positive_code=False,
    method_max_iter=1000,
):
    """Solve a dictionary learning matrix factorization problem.

    Finds the best dictionary and the corresponding sparse code for
    approximating the data matrix X by solving::

        (U^*, V^*) = argmin 0.5 || X - U V ||_Fro^2 + alpha * || U ||_1,1
                     (U,V)
                     with || V_k ||_2 = 1 for all 0 <= k < n_components

    where V is the dictionary and U is the sparse code. ||.||_Fro stands for
    the Frobenius norm and ||.||_1,1 stands for the entry-wise matrix norm
    which is the sum of the absolute values of all the entries in the matrix.

    Read more in the :ref:`User Guide <DictionaryLearning>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data matrix.

    n_components : int
        Number of dictionary atoms to extract.

    alpha : int or float
        Sparsity controlling parameter.

    max_iter : int, default=100
        Maximum number of iterations to perform.

    tol : float, default=1e-8
        Tolerance for the stopping condition.

    method : {'lars', 'cd'}, default='lars'
        The method used:

        * `'lars'`: uses the least angle regression method to solve the lasso
          problem (`linear_model.lars_path`);
        * `'cd'`: uses the coordinate descent method to compute the
          Lasso solution (`linear_model.Lasso`). Lars will be faster if
          the estimated components are sparse.

    n_jobs : int, default=None
        Number of parallel jobs to run.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    dict_init : ndarray of shape (n_components, n_features), default=None
        Initial value for the dictionary for warm restart scenarios. Only used
        if `code_init` and `dict_init` are not None.

    code_init : ndarray of shape (n_samples, n_components), default=None
        Initial value for the sparse code for warm restart scenarios. Only used
        if `code_init` and `dict_init` are not None.

    callback : callable, default=None
        Callable that gets invoked every five iterations.

    verbose : bool, default=False
        To control the verbosity of the procedure.

    random_state : int, RandomState instance or None, default=None
        Used for randomly initializing the dictionary. Pass an int for
        reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    return_n_iter : bool, default=False
        Whether or not to return the number of iterations.

    positive_dict : bool, default=False
        Whether to enforce positivity when finding the dictionary.

        .. versionadded:: 0.20

    positive_code : bool, default=False
        Whether to enforce positivity when finding the code.

        .. versionadded:: 0.20

    method_max_iter : int, default=1000
        Maximum number of iterations to perform.

        .. versionadded:: 0.22

    Returns
    -------
    code : ndarray of shape (n_samples, n_components)
        The sparse code factor in the matrix factorization.

    dictionary : ndarray of shape (n_components, n_features)
        The dictionary factor in the matrix factorization.

    errors : array
        Vector of errors at each iteration.

    n_iter : int
        Number of iterations run. Returned only if `return_n_iter` is
        set to True.

    See Also
    --------
    dict_learning_online : Solve a dictionary learning matrix factorization
        problem online.
    DictionaryLearning : Find a dictionary that sparsely encodes data.
    MiniBatchDictionaryLearning : A faster, less accurate version
        of the dictionary learning algorithm.
    SparsePCA : Sparse Principal Components Analysis.
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    """
    estimator = DictionaryLearning(
        n_components=n_components,
        alpha=alpha,
        max_iter=max_iter,
        tol=tol,
        fit_algorithm=method,
        n_jobs=n_jobs,
        dict_init=dict_init,
        callback=callback,
        code_init=code_init,
        verbose=verbose,
        random_state=random_state,
        positive_code=positive_code,
        positive_dict=positive_dict,
        transform_max_iter=method_max_iter,
    )
    code = estimator.fit_transform(X)
    if return_n_iter:
        return (
            code,
            estimator.components_,
            estimator.error_,
            estimator.n_iter_,
        )
    return code, estimator.components_, estimator.error_
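

# A minimal usage sketch of `dict_learning` (illustrative, not part of the
# module): batch dictionary learning on a small random matrix. `errors` holds
# the value of the cost function after each iteration.
#
#     >>> import numpy as np
#     >>> from sklearn.decomposition import dict_learning
#     >>> rng = np.random.RandomState(0)
#     >>> X = rng.randn(20, 6)
#     >>> code, dictionary, errors = dict_learning(
#     ...     X, n_components=4, alpha=1.0, random_state=0)
#     >>> code.shape, dictionary.shape
#     ((20, 4), (4, 6))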


class _BaseSparseCoding(ClassNamePrefixFeaturesOutMixin, TransformerMixin):
    """Base class for SparseCoder and DictionaryLearning algorithms."""

    def __init__(
        self,
        transform_algorithm,
        transform_n_nonzero_coefs,
        transform_alpha,
        split_sign,
        n_jobs,
        positive_code,
        transform_max_iter,
    ):
        self.transform_algorithm = transform_algorithm
        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs
        self.transform_alpha = transform_alpha
        self.transform_max_iter = transform_max_iter
        self.split_sign = split_sign
        self.n_jobs = n_jobs
        self.positive_code = positive_code

    def _transform(self, X, dictionary):
        """Private method to accommodate both DictionaryLearning and
        SparseCoder."""
        X = self._validate_data(X, reset=False)

        if hasattr(self, "alpha") and self.transform_alpha is None:
            transform_alpha = self.alpha
        else:
            transform_alpha = self.transform_alpha

        code = sparse_encode(
            X,
            dictionary,
            algorithm=self.transform_algorithm,
            n_nonzero_coefs=self.transform_n_nonzero_coefs,
            alpha=transform_alpha,
            max_iter=self.transform_max_iter,
            n_jobs=self.n_jobs,
            positive=self.positive_code,
        )

        if self.split_sign:
            # feature vector is split into a positive and negative side
            n_samples, n_features = code.shape
            split_code = np.empty((n_samples, 2 * n_features))
            split_code[:, :n_features] = np.maximum(code, 0)
            split_code[:, n_features:] = -np.minimum(code, 0)
            code = split_code

        return code
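
    # Illustrative note (not library documentation): with `split_sign=True` a
    # code row such as [2., -3.] becomes [2., 0., 0., 3.] (the positive parts
    # first, then the magnitudes of the negative parts), doubling the number
    # of output features.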

    def transform(self, X):
        """Encode the data as a sparse combination of the dictionary atoms.

        Coding method is determined by the object parameter
        `transform_algorithm`.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Test data to be transformed, must have the same number of
            features as the data used to train the model.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Transformed data.
        """
        check_is_fitted(self)
        return self._transform(X, self.components_)
  1095. class SparseCoder(_BaseSparseCoding, BaseEstimator):
  1096. """Sparse coding.
  1097. Finds a sparse representation of data against a fixed, precomputed
  1098. dictionary.
  1099. Each row of the result is the solution to a sparse coding problem.
  1100. The goal is to find a sparse array `code` such that::
  1101. X ~= code * dictionary
  1102. Read more in the :ref:`User Guide <SparseCoder>`.
  1103. Parameters
  1104. ----------
  1105. dictionary : ndarray of shape (n_components, n_features)
  1106. The dictionary atoms used for sparse coding. Lines are assumed to be
  1107. normalized to unit norm.
  1108. transform_algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', \
  1109. 'threshold'}, default='omp'
  1110. Algorithm used to transform the data:
  1111. - `'lars'`: uses the least angle regression method
  1112. (`linear_model.lars_path`);
  1113. - `'lasso_lars'`: uses Lars to compute the Lasso solution;
  1114. - `'lasso_cd'`: uses the coordinate descent method to compute the
  1115. Lasso solution (linear_model.Lasso). `'lasso_lars'` will be faster if
  1116. the estimated components are sparse;
  1117. - `'omp'`: uses orthogonal matching pursuit to estimate the sparse
  1118. solution;
  1119. - `'threshold'`: squashes to zero all coefficients less than alpha from
  1120. the projection ``dictionary * X'``.
  1121. transform_n_nonzero_coefs : int, default=None
  1122. Number of nonzero coefficients to target in each column of the
  1123. solution. This is only used by `algorithm='lars'` and `algorithm='omp'`
  1124. and is overridden by `alpha` in the `omp` case. If `None`, then
  1125. `transform_n_nonzero_coefs=int(n_features / 10)`.
  1126. transform_alpha : float, default=None
  1127. If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
  1128. penalty applied to the L1 norm.
  1129. If `algorithm='threshold'`, `alpha` is the absolute value of the
  1130. threshold below which coefficients will be squashed to zero.
  1131. If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of
  1132. the reconstruction error targeted. In this case, it overrides
  1133. `n_nonzero_coefs`.
  1134. If `None`, default to 1.
  1135. split_sign : bool, default=False
  1136. Whether to split the sparse feature vector into the concatenation of
  1137. its negative part and its positive part. This can improve the
  1138. performance of downstream classifiers.
  1139. n_jobs : int, default=None
  1140. Number of parallel jobs to run.
  1141. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
  1142. ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
  1143. for more details.
  1144. positive_code : bool, default=False
  1145. Whether to enforce positivity when finding the code.
  1146. .. versionadded:: 0.20
  1147. transform_max_iter : int, default=1000
  1148. Maximum number of iterations to perform if `algorithm='lasso_cd'` or
  1149. `lasso_lars`.
  1150. .. versionadded:: 0.22
  1151. Attributes
  1152. ----------
  1153. n_components_ : int
  1154. Number of atoms.
  1155. n_features_in_ : int
  1156. Number of features seen during :term:`fit`.
  1157. .. versionadded:: 0.24
  1158. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1159. Names of features seen during :term:`fit`. Defined only when `X`
  1160. has feature names that are all strings.
  1161. .. versionadded:: 1.0
  1162. See Also
  1163. --------
  1164. DictionaryLearning : Find a dictionary that sparsely encodes data.
  1165. MiniBatchDictionaryLearning : A faster, less accurate, version of the
  1166. dictionary learning algorithm.
  1167. MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
  1168. SparsePCA : Sparse Principal Components Analysis.
  1169. sparse_encode : Sparse coding where each row of the result is the solution
  1170. to a sparse coding problem.
  1171. Examples
  1172. --------
  1173. >>> import numpy as np
  1174. >>> from sklearn.decomposition import SparseCoder
  1175. >>> X = np.array([[-1, -1, -1], [0, 0, 3]])
  1176. >>> dictionary = np.array(
  1177. ... [[0, 1, 0],
  1178. ... [-1, -1, 2],
  1179. ... [1, 1, 1],
  1180. ... [0, 1, 1],
  1181. ... [0, 2, 1]],
  1182. ... dtype=np.float64
  1183. ... )
  1184. >>> coder = SparseCoder(
  1185. ... dictionary=dictionary, transform_algorithm='lasso_lars',
  1186. ... transform_alpha=1e-10,
  1187. ... )
  1188. >>> coder.transform(X)
  1189. array([[ 0., 0., -1., 0., 0.],
  1190. [ 0., 1., 1., 0., 0.]])
  1191. """
  1192. _required_parameters = ["dictionary"]
  1193. def __init__(
  1194. self,
  1195. dictionary,
  1196. *,
  1197. transform_algorithm="omp",
  1198. transform_n_nonzero_coefs=None,
  1199. transform_alpha=None,
  1200. split_sign=False,
  1201. n_jobs=None,
  1202. positive_code=False,
  1203. transform_max_iter=1000,
  1204. ):
  1205. super().__init__(
  1206. transform_algorithm,
  1207. transform_n_nonzero_coefs,
  1208. transform_alpha,
  1209. split_sign,
  1210. n_jobs,
  1211. positive_code,
  1212. transform_max_iter,
  1213. )
  1214. self.dictionary = dictionary
  1215. def fit(self, X, y=None):
  1216. """Do nothing and return the estimator unchanged.
  1217. This method is just there to implement the usual API and hence
  1218. work in pipelines.
  1219. Parameters
  1220. ----------
  1221. X : Ignored
  1222. Not used, present for API consistency by convention.
  1223. y : Ignored
  1224. Not used, present for API consistency by convention.
  1225. Returns
  1226. -------
  1227. self : object
  1228. Returns the instance itself.
  1229. """
  1230. return self

    def transform(self, X, y=None):
        """Encode the data as a sparse combination of the dictionary atoms.

        Coding method is determined by the object parameter
        `transform_algorithm`.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_components)
            Transformed data.
        """
        return super()._transform(X, self.dictionary)
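
    # A minimal usage sketch beyond the docstring example above, assuming a
    # hypothetical identity dictionary and the default 'omp' algorithm with
    # an explicit sparsity target:
    #
    #     import numpy as np
    #     from sklearn.decomposition import SparseCoder
    #     D = np.eye(4)                      # 4 atoms, 4 features
    #     coder = SparseCoder(dictionary=D, transform_n_nonzero_coefs=2)
    #     codes = coder.transform(np.array([[0.0, 1.0, 0.0, 2.0]]))
    #     # each row of `codes` has at most 2 nonzero coefficients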

    def _more_tags(self):
        return {
            "requires_fit": False,
            "preserves_dtype": [np.float64, np.float32],
        }

    @property
    def n_components_(self):
        """Number of atoms."""
        return self.dictionary.shape[0]

    @property
    def n_features_in_(self):
        """Number of features seen during `fit`."""
        return self.dictionary.shape[1]

    @property
    def _n_features_out(self):
        """Number of transformed output features."""
        return self.n_components_
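
# For a fixed dictionary of shape (n_components, n_features), SparseCoder maps
# inputs of shape (n_samples, n_features) to codes of shape
# (n_samples, n_components). A minimal shape check, assuming hypothetical
# values:
#
#     import numpy as np
#     from sklearn.decomposition import SparseCoder
#     D = np.random.RandomState(0).randn(8, 5)   # 8 atoms, 5 features
#     coder = SparseCoder(dictionary=D)
#     assert coder.n_components_ == 8
#     assert coder.n_features_in_ == 5
#     assert coder.transform(np.ones((3, 5))).shape == (3, 8)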


class DictionaryLearning(_BaseSparseCoding, BaseEstimator):
    """Dictionary learning.

    Finds a dictionary (a set of atoms) that performs well at sparsely
    encoding the fitted data.

    Solves the optimization problem::

        (U^*,V^*) = argmin 0.5 || X - U V ||_Fro^2 + alpha * || U ||_1,1
                     (U,V)
                    with || V_k ||_2 <= 1 for all  0 <= k < n_components

    ||.||_Fro stands for the Frobenius norm and ||.||_1,1 stands for
    the entry-wise matrix norm which is the sum of the absolute values
    of all the entries in the matrix.

    Read more in the :ref:`User Guide <DictionaryLearning>`.

    Parameters
    ----------
    n_components : int, default=None
        Number of dictionary elements to extract. If None, then ``n_components``
        is set to ``n_features``.

    alpha : float, default=1.0
        Sparsity controlling parameter.

    max_iter : int, default=1000
        Maximum number of iterations to perform.

    tol : float, default=1e-8
        Tolerance for numerical error.

    fit_algorithm : {'lars', 'cd'}, default='lars'
        * `'lars'`: uses the least angle regression method to solve the lasso
          problem (:func:`~sklearn.linear_model.lars_path`);
        * `'cd'`: uses the coordinate descent method to compute the
          Lasso solution (:class:`~sklearn.linear_model.Lasso`). Lars will be
          faster if the estimated components are sparse.

        .. versionadded:: 0.17
           *cd* coordinate descent method to improve speed.

    transform_algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', \
            'threshold'}, default='omp'
        Algorithm used to transform the data:

        - `'lars'`: uses the least angle regression method
          (:func:`~sklearn.linear_model.lars_path`);
        - `'lasso_lars'`: uses Lars to compute the Lasso solution.
        - `'lasso_cd'`: uses the coordinate descent method to compute the
          Lasso solution (:class:`~sklearn.linear_model.Lasso`). `'lasso_lars'`
          will be faster if the estimated components are sparse.
        - `'omp'`: uses orthogonal matching pursuit to estimate the sparse
          solution.
        - `'threshold'`: squashes to zero all coefficients less than alpha from
          the projection ``dictionary * X'``.

        .. versionadded:: 0.17
           *lasso_cd* coordinate descent method to improve speed.

    transform_n_nonzero_coefs : int, default=None
        Number of nonzero coefficients to target in each column of the
        solution. This is only used by `algorithm='lars'` and
        `algorithm='omp'`. If `None`, then
        `transform_n_nonzero_coefs=int(n_features / 10)`.

    transform_alpha : float, default=None
        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
        penalty applied to the L1 norm.
        If `algorithm='threshold'`, `alpha` is the absolute value of the
        threshold below which coefficients will be squashed to zero.
        If `None`, defaults to `alpha`.

        .. versionchanged:: 1.2
            When None, default value changed from 1.0 to `alpha`.

    n_jobs : int or None, default=None
        Number of parallel jobs to run.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    code_init : ndarray of shape (n_samples, n_components), default=None
        Initial value for the code, for warm restart. Only used if `code_init`
        and `dict_init` are not None.

    dict_init : ndarray of shape (n_components, n_features), default=None
        Initial values for the dictionary, for warm restart. Only used if
        `code_init` and `dict_init` are not None.

    callback : callable, default=None
        Callable that gets invoked every five iterations.

        .. versionadded:: 1.3

    verbose : bool, default=False
        To control the verbosity of the procedure.

    split_sign : bool, default=False
        Whether to split the sparse feature vector into the concatenation of
        its negative part and its positive part. This can improve the
        performance of downstream classifiers.

    random_state : int, RandomState instance or None, default=None
        Used for initializing the dictionary when ``dict_init`` is not
        specified, randomly shuffling the data when ``shuffle`` is set to
        ``True``, and updating the dictionary. Pass an int for reproducible
        results across multiple function calls.
        See :term:`Glossary <random_state>`.

    positive_code : bool, default=False
        Whether to enforce positivity when finding the code.

        .. versionadded:: 0.20

    positive_dict : bool, default=False
        Whether to enforce positivity when finding the dictionary.

        .. versionadded:: 0.20

    transform_max_iter : int, default=1000
        Maximum number of iterations to perform if `algorithm='lasso_cd'` or
        `'lasso_lars'`.

        .. versionadded:: 0.22

    Attributes
    ----------
    components_ : ndarray of shape (n_components, n_features)
        Dictionary atoms extracted from the data.

    error_ : array
        Vector of errors at each iteration.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        Number of iterations run.

    See Also
    --------
    MiniBatchDictionaryLearning : A faster, less accurate, version of the
        dictionary learning algorithm.
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    SparseCoder : Find a sparse representation of data from a fixed,
        precomputed dictionary.
    SparsePCA : Sparse Principal Components Analysis.

    References
    ----------

    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
    for sparse coding (https://www.di.ens.fr/sierra/pdfs/icml09.pdf)

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.datasets import make_sparse_coded_signal
    >>> from sklearn.decomposition import DictionaryLearning
    >>> X, dictionary, code = make_sparse_coded_signal(
    ...     n_samples=30, n_components=15, n_features=20, n_nonzero_coefs=10,
    ...     random_state=42,
    ... )
    >>> dict_learner = DictionaryLearning(
    ...     n_components=15, transform_algorithm='lasso_lars', transform_alpha=0.1,
    ...     random_state=42,
    ... )
    >>> X_transformed = dict_learner.fit(X).transform(X)

    We can check the level of sparsity of `X_transformed`:

    >>> np.mean(X_transformed == 0)
    0.52...

    We can compare the average squared Euclidean norm of the reconstruction
    error of the sparse coded signal relative to the squared Euclidean norm of
    the original signal:

    >>> X_hat = X_transformed @ dict_learner.components_
    >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
    0.05...
    """

    _parameter_constraints: dict = {
        "n_components": [Interval(Integral, 1, None, closed="left"), None],
        "alpha": [Interval(Real, 0, None, closed="left")],
        "max_iter": [Interval(Integral, 0, None, closed="left")],
        "tol": [Interval(Real, 0, None, closed="left")],
        "fit_algorithm": [StrOptions({"lars", "cd"})],
        "transform_algorithm": [
            StrOptions({"lasso_lars", "lasso_cd", "lars", "omp", "threshold"})
        ],
        "transform_n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
        "transform_alpha": [Interval(Real, 0, None, closed="left"), None],
        "n_jobs": [Integral, None],
        "code_init": [np.ndarray, None],
        "dict_init": [np.ndarray, None],
        "callback": [callable, None],
        "verbose": ["verbose"],
        "split_sign": ["boolean"],
        "random_state": ["random_state"],
        "positive_code": ["boolean"],
        "positive_dict": ["boolean"],
        "transform_max_iter": [Interval(Integral, 0, None, closed="left")],
    }

    def __init__(
        self,
        n_components=None,
        *,
        alpha=1,
        max_iter=1000,
        tol=1e-8,
        fit_algorithm="lars",
        transform_algorithm="omp",
        transform_n_nonzero_coefs=None,
        transform_alpha=None,
        n_jobs=None,
        code_init=None,
        dict_init=None,
        callback=None,
        verbose=False,
        split_sign=False,
        random_state=None,
        positive_code=False,
        positive_dict=False,
        transform_max_iter=1000,
    ):
        super().__init__(
            transform_algorithm,
            transform_n_nonzero_coefs,
            transform_alpha,
            split_sign,
            n_jobs,
            positive_code,
            transform_max_iter,
        )
        self.n_components = n_components
        self.alpha = alpha
        self.max_iter = max_iter
        self.tol = tol
        self.fit_algorithm = fit_algorithm
        self.code_init = code_init
        self.dict_init = dict_init
        self.callback = callback
        self.verbose = verbose
        self.random_state = random_state
        self.positive_dict = positive_dict

    def fit(self, X, y=None):
        """Fit the model from data in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        self.fit_transform(X)
        return self

    @_fit_context(prefer_skip_nested_validation=True)
    def fit_transform(self, X, y=None):
        """Fit the model from data in X and return the transformed data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        V : ndarray of shape (n_samples, n_components)
            Transformed data.
        """
        _check_positive_coding(method=self.fit_algorithm, positive=self.positive_code)

        method = "lasso_" + self.fit_algorithm

        random_state = check_random_state(self.random_state)
        X = self._validate_data(X)

        if self.n_components is None:
            n_components = X.shape[1]
        else:
            n_components = self.n_components

        V, U, E, self.n_iter_ = _dict_learning(
            X,
            n_components,
            alpha=self.alpha,
            tol=self.tol,
            max_iter=self.max_iter,
            method=method,
            method_max_iter=self.transform_max_iter,
            n_jobs=self.n_jobs,
            code_init=self.code_init,
            dict_init=self.dict_init,
            callback=self.callback,
            verbose=self.verbose,
            random_state=random_state,
            return_n_iter=True,
            positive_dict=self.positive_dict,
            positive_code=self.positive_code,
        )
        self.components_ = U
        self.error_ = E
        return V
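
    # Note that `fit_transform` returns the codes computed during dictionary
    # learning itself (method `"lasso_" + fit_algorithm`), whereas a
    # subsequent call to `transform` re-encodes the data with
    # `transform_algorithm`. A minimal sketch of the difference, assuming
    # hypothetical random data (the two code matrices generally differ):
    #
    #     import numpy as np
    #     from sklearn.decomposition import DictionaryLearning
    #     X = np.random.RandomState(0).randn(20, 10)
    #     dl = DictionaryLearning(n_components=5, random_state=0)
    #     V_fit = dl.fit_transform(X)    # lasso codes from the fit
    #     V_omp = dl.transform(X)        # codes from transform_algorithm='omp'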

    @property
    def _n_features_out(self):
        """Number of transformed output features."""
        return self.components_.shape[0]

    def _more_tags(self):
        return {
            "preserves_dtype": [np.float64, np.float32],
        }


class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
    """Mini-batch dictionary learning.

    Finds a dictionary (a set of atoms) that performs well at sparsely
    encoding the fitted data.

    Solves the optimization problem::

        (U^*,V^*) = argmin 0.5 || X - U V ||_Fro^2 + alpha * || U ||_1,1
                     (U,V)
                    with || V_k ||_2 <= 1 for all  0 <= k < n_components

    ||.||_Fro stands for the Frobenius norm and ||.||_1,1 stands for
    the entry-wise matrix norm which is the sum of the absolute values
    of all the entries in the matrix.

    Read more in the :ref:`User Guide <DictionaryLearning>`.

    Parameters
    ----------
    n_components : int, default=None
        Number of dictionary elements to extract.

    alpha : float, default=1
        Sparsity controlling parameter.

    n_iter : int, default=1000
        Total number of iterations over data batches to perform.

        .. deprecated:: 1.1
           ``n_iter`` is deprecated in 1.1 and will be removed in 1.4. Use
           ``max_iter`` instead.

    max_iter : int, default=None
        Maximum number of iterations over the complete dataset before
        stopping independently of any early stopping criterion heuristics.
        If ``max_iter`` is not None, ``n_iter`` is ignored.

        .. versionadded:: 1.1

    fit_algorithm : {'lars', 'cd'}, default='lars'
        The algorithm used:

        - `'lars'`: uses the least angle regression method to solve the lasso
          problem (`linear_model.lars_path`)
        - `'cd'`: uses the coordinate descent method to compute the
          Lasso solution (`linear_model.Lasso`). Lars will be faster if
          the estimated components are sparse.

    n_jobs : int, default=None
        Number of parallel jobs to run.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    batch_size : int, default=256
        Number of samples in each mini-batch.

        .. versionchanged:: 1.3
           The default value of `batch_size` changed from 3 to 256.

    shuffle : bool, default=True
        Whether to shuffle the samples before forming batches.

    dict_init : ndarray of shape (n_components, n_features), default=None
        Initial value of the dictionary for warm restart scenarios.

    transform_algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', \
            'threshold'}, default='omp'
        Algorithm used to transform the data:

        - `'lars'`: uses the least angle regression method
          (`linear_model.lars_path`);
        - `'lasso_lars'`: uses Lars to compute the Lasso solution.
        - `'lasso_cd'`: uses the coordinate descent method to compute the
          Lasso solution (`linear_model.Lasso`). `'lasso_lars'` will be faster
          if the estimated components are sparse.
        - `'omp'`: uses orthogonal matching pursuit to estimate the sparse
          solution.
        - `'threshold'`: squashes to zero all coefficients less than alpha from
          the projection ``dictionary * X'``.

    transform_n_nonzero_coefs : int, default=None
        Number of nonzero coefficients to target in each column of the
        solution. This is only used by `algorithm='lars'` and
        `algorithm='omp'`. If `None`, then
        `transform_n_nonzero_coefs=int(n_features / 10)`.

    transform_alpha : float, default=None
        If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the
        penalty applied to the L1 norm.
        If `algorithm='threshold'`, `alpha` is the absolute value of the
        threshold below which coefficients will be squashed to zero.
        If `None`, defaults to `alpha`.

        .. versionchanged:: 1.2
            When None, default value changed from 1.0 to `alpha`.

    verbose : bool or int, default=False
        To control the verbosity of the procedure.

    split_sign : bool, default=False
        Whether to split the sparse feature vector into the concatenation of
        its negative part and its positive part. This can improve the
        performance of downstream classifiers.

    random_state : int, RandomState instance or None, default=None
        Used for initializing the dictionary when ``dict_init`` is not
        specified, randomly shuffling the data when ``shuffle`` is set to
        ``True``, and updating the dictionary. Pass an int for reproducible
        results across multiple function calls.
        See :term:`Glossary <random_state>`.

    positive_code : bool, default=False
        Whether to enforce positivity when finding the code.

        .. versionadded:: 0.20

    positive_dict : bool, default=False
        Whether to enforce positivity when finding the dictionary.

        .. versionadded:: 0.20

    transform_max_iter : int, default=1000
        Maximum number of iterations to perform if `algorithm='lasso_cd'` or
        `'lasso_lars'`.

        .. versionadded:: 0.22

    callback : callable, default=None
        A callable that gets invoked at the end of each iteration.

        .. versionadded:: 1.1

    tol : float, default=1e-3
        Control early stopping based on the norm of the differences in the
        dictionary between 2 steps. Used only if `max_iter` is not None.

        To disable early stopping based on changes in the dictionary, set
        `tol` to 0.0.

        .. versionadded:: 1.1

    max_no_improvement : int, default=10
        Control early stopping based on the consecutive number of mini batches
        that do not yield an improvement on the smoothed cost function. Used
        only if `max_iter` is not None.

        To disable convergence detection based on cost function, set
        `max_no_improvement` to None.

        .. versionadded:: 1.1

    Attributes
    ----------
    components_ : ndarray of shape (n_components, n_features)
        Components extracted from the data.

    n_features_in_ : int
        Number of features seen during :term:`fit`.

        .. versionadded:: 0.24

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    n_iter_ : int
        Number of iterations over the full dataset.

    n_steps_ : int
        Number of mini-batches processed.

        .. versionadded:: 1.1

    See Also
    --------
    DictionaryLearning : Find a dictionary that sparsely encodes data.
    MiniBatchSparsePCA : Mini-batch Sparse Principal Components Analysis.
    SparseCoder : Find a sparse representation of data from a fixed,
        precomputed dictionary.
    SparsePCA : Sparse Principal Components Analysis.

    References
    ----------

    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
    for sparse coding (https://www.di.ens.fr/sierra/pdfs/icml09.pdf)

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.datasets import make_sparse_coded_signal
    >>> from sklearn.decomposition import MiniBatchDictionaryLearning
    >>> X, dictionary, code = make_sparse_coded_signal(
    ...     n_samples=30, n_components=15, n_features=20, n_nonzero_coefs=10,
    ...     random_state=42)
    >>> dict_learner = MiniBatchDictionaryLearning(
    ...     n_components=15, batch_size=3, transform_algorithm='lasso_lars',
    ...     transform_alpha=0.1, max_iter=20, random_state=42)
    >>> X_transformed = dict_learner.fit_transform(X)

    We can check the level of sparsity of `X_transformed`:

    >>> np.mean(X_transformed == 0) > 0.5
    True

    We can compare the average squared Euclidean norm of the reconstruction
    error of the sparse coded signal relative to the squared Euclidean norm of
    the original signal:

    >>> X_hat = X_transformed @ dict_learner.components_
    >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
    0.052...
    """

    _parameter_constraints: dict = {
        "n_components": [Interval(Integral, 1, None, closed="left"), None],
        "alpha": [Interval(Real, 0, None, closed="left")],
        "n_iter": [
            Interval(Integral, 0, None, closed="left"),
            Hidden(StrOptions({"deprecated"})),
        ],
        "max_iter": [Interval(Integral, 0, None, closed="left"), None],
        "fit_algorithm": [StrOptions({"cd", "lars"})],
        "n_jobs": [None, Integral],
        "batch_size": [Interval(Integral, 1, None, closed="left")],
        "shuffle": ["boolean"],
        "dict_init": [None, np.ndarray],
        "transform_algorithm": [
            StrOptions({"lasso_lars", "lasso_cd", "lars", "omp", "threshold"})
        ],
        "transform_n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
        "transform_alpha": [Interval(Real, 0, None, closed="left"), None],
        "verbose": ["verbose"],
        "split_sign": ["boolean"],
        "random_state": ["random_state"],
        "positive_code": ["boolean"],
        "positive_dict": ["boolean"],
        "transform_max_iter": [Interval(Integral, 0, None, closed="left")],
        "callback": [None, callable],
        "tol": [Interval(Real, 0, None, closed="left")],
        "max_no_improvement": [Interval(Integral, 0, None, closed="left"), None],
    }

    def __init__(
        self,
        n_components=None,
        *,
        alpha=1,
        n_iter="deprecated",
        max_iter=None,
        fit_algorithm="lars",
        n_jobs=None,
        batch_size=256,
        shuffle=True,
        dict_init=None,
        transform_algorithm="omp",
        transform_n_nonzero_coefs=None,
        transform_alpha=None,
        verbose=False,
        split_sign=False,
        random_state=None,
        positive_code=False,
        positive_dict=False,
        transform_max_iter=1000,
        callback=None,
        tol=1e-3,
        max_no_improvement=10,
    ):
        super().__init__(
            transform_algorithm,
            transform_n_nonzero_coefs,
            transform_alpha,
            split_sign,
            n_jobs,
            positive_code,
            transform_max_iter,
        )
        self.n_components = n_components
        self.alpha = alpha
        self.n_iter = n_iter
        self.max_iter = max_iter
        self.fit_algorithm = fit_algorithm
        self.dict_init = dict_init
        self.verbose = verbose
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.split_sign = split_sign
        self.random_state = random_state
        self.positive_dict = positive_dict
        self.callback = callback
        self.max_no_improvement = max_no_improvement
        self.tol = tol

    def _check_params(self, X):
        # n_components
        self._n_components = self.n_components
        if self._n_components is None:
            self._n_components = X.shape[1]

        # fit_algorithm
        _check_positive_coding(self.fit_algorithm, self.positive_code)
        self._fit_algorithm = "lasso_" + self.fit_algorithm

        # batch_size
        self._batch_size = min(self.batch_size, X.shape[0])

    def _initialize_dict(self, X, random_state):
        """Initialization of the dictionary."""
        if self.dict_init is not None:
            dictionary = self.dict_init
        else:
            # Init V with SVD of X
            _, S, dictionary = randomized_svd(
                X, self._n_components, random_state=random_state
            )
            dictionary = S[:, np.newaxis] * dictionary

        if self._n_components <= len(dictionary):
            dictionary = dictionary[: self._n_components, :]
        else:
            dictionary = np.concatenate(
                (
                    dictionary,
                    np.zeros(
                        (self._n_components - len(dictionary), dictionary.shape[1]),
                        dtype=dictionary.dtype,
                    ),
                )
            )

        dictionary = check_array(dictionary, order="F", dtype=X.dtype, copy=False)
        dictionary = np.require(dictionary, requirements="W")

        return dictionary
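
    # In the SVD branch above, each right singular vector is scaled by its
    # singular value, so the initial atoms are scaled by the data's energy
    # along the principal directions. A minimal sketch of the shape logic,
    # assuming hypothetical values:
    #
    #     import numpy as np
    #     from sklearn.utils.extmath import randomized_svd
    #     X = np.random.RandomState(0).randn(50, 12)
    #     _, S, Vt = randomized_svd(X, n_components=5, random_state=0)
    #     D0 = S[:, np.newaxis] * Vt     # shape (5, 12): 5 atoms, 12 features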

    def _update_inner_stats(self, X, code, batch_size, step):
        """Update the inner stats inplace."""
        if step < batch_size - 1:
            theta = (step + 1) * batch_size
        else:
            theta = batch_size**2 + step + 1 - batch_size
        beta = (theta + 1 - batch_size) / (theta + 1)

        self._A *= beta
        self._A += code.T @ code / batch_size
        self._B *= beta
        self._B += X.T @ code / batch_size
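
    # `_A` and `_B` are forgetting-factor running averages of the sufficient
    # statistics from the Mairal et al. (2009) reference:
    #
    #     A ~ running average of code.T @ code   (Gram matrix of the codes)
    #     B ~ running average of X.T @ code      (data/code cross-product)
    #
    # so the dictionary update only needs A and B, never the past batches
    # themselves.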

    def _minibatch_step(self, X, dictionary, random_state, step):
        """Perform the update on the dictionary for one minibatch."""
        batch_size = X.shape[0]

        # Compute code for this batch
        code = _sparse_encode(
            X,
            dictionary,
            algorithm=self._fit_algorithm,
            alpha=self.alpha,
            n_jobs=self.n_jobs,
            positive=self.positive_code,
            max_iter=self.transform_max_iter,
            verbose=self.verbose,
        )

        batch_cost = (
            0.5 * ((X - code @ dictionary) ** 2).sum()
            + self.alpha * np.sum(np.abs(code))
        ) / batch_size

        # Update inner stats
        self._update_inner_stats(X, code, batch_size, step)

        # Update dictionary
        _update_dict(
            dictionary,
            X,
            code,
            self._A,
            self._B,
            verbose=self.verbose,
            random_state=random_state,
            positive=self.positive_dict,
        )

        return batch_cost
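
    # `batch_cost` computed above is the per-sample value of the objective
    # being minimized, i.e.
    # (0.5 * ||X - code @ dictionary||_Fro^2 + alpha * ||code||_1) / batch_size;
    # it is only used for monitoring and early stopping, not for the update.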

    def _check_convergence(
        self, X, batch_cost, new_dict, old_dict, n_samples, step, n_steps
    ):
        """Helper function to encapsulate the early stopping logic.

        Early stopping is based on two factors:
        - A small change of the dictionary between two minibatch updates. This is
          controlled by the tol parameter.
        - No more improvement on a smoothed estimate of the objective function
          for a certain number of consecutive minibatch updates. This is
          controlled by the max_no_improvement parameter.
        """
        batch_size = X.shape[0]

        # counts steps starting from 1 for user friendly verbose mode.
        step = step + 1

        # Ignore the first 100 steps, or the first epoch, to avoid initializing
        # the ewa_cost with too bad a value.
        if step <= min(100, n_samples / batch_size):
            if self.verbose:
                print(f"Minibatch step {step}/{n_steps}: mean batch cost: {batch_cost}")
            return False

        # Compute an Exponentially Weighted Average of the cost function to
        # monitor the convergence while discarding minibatch-local stochastic
        # variability: https://en.wikipedia.org/wiki/Moving_average
        if self._ewa_cost is None:
            self._ewa_cost = batch_cost
        else:
            alpha = batch_size / (n_samples + 1)
            alpha = min(alpha, 1)
            self._ewa_cost = self._ewa_cost * (1 - alpha) + batch_cost * alpha

        if self.verbose:
            print(
                f"Minibatch step {step}/{n_steps}: mean batch cost: "
                f"{batch_cost}, ewa cost: {self._ewa_cost}"
            )

        # Early stopping based on change of dictionary
        dict_diff = linalg.norm(new_dict - old_dict) / self._n_components
        if self.tol > 0 and dict_diff <= self.tol:
            if self.verbose:
                print(f"Converged (small dictionary change) at step {step}/{n_steps}")
            return True

        # Early stopping heuristic due to lack of improvement on smoothed
        # cost function
        if self._ewa_cost_min is None or self._ewa_cost < self._ewa_cost_min:
            self._no_improvement = 0
            self._ewa_cost_min = self._ewa_cost
        else:
            self._no_improvement += 1

        if (
            self.max_no_improvement is not None
            and self._no_improvement >= self.max_no_improvement
        ):
            if self.verbose:
                print(
                    "Converged (lack of improvement in objective function) "
                    f"at step {step}/{n_steps}"
                )
            return True

        return False
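
    # A minimal sketch of the exponentially weighted average used above,
    # assuming a hypothetical stream of batch costs; each new cost moves the
    # estimate by alpha = batch_size / (n_samples + 1):
    #
    #     costs = [5.0, 4.0, 4.2, 3.9]
    #     ewa, alpha = costs[0], 0.25
    #     for c in costs[1:]:
    #         ewa = ewa * (1 - alpha) + c * alpha
    #     # `ewa` smooths out batch-to-batch noise in `costs`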

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        """Fit the model from data in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        X = self._validate_data(
            X, dtype=[np.float64, np.float32], order="C", copy=False
        )

        self._check_params(X)

        if self.n_iter != "deprecated":
            warnings.warn(
                (
                    "'n_iter' is deprecated in version 1.1 and will be removed "
                    "in version 1.4. Use 'max_iter' and leave 'n_iter' at its "
                    "default value instead. 'n_iter' is also ignored if "
                    "'max_iter' is specified."
                ),
                FutureWarning,
            )
            n_iter = self.n_iter

        self._random_state = check_random_state(self.random_state)

        dictionary = self._initialize_dict(X, self._random_state)
        old_dict = dictionary.copy()

        if self.shuffle:
            X_train = X.copy()
            self._random_state.shuffle(X_train)
        else:
            X_train = X

        n_samples, n_features = X_train.shape

        if self.verbose:
            print("[dict_learning]")

        # Inner stats
        self._A = np.zeros(
            (self._n_components, self._n_components), dtype=X_train.dtype
        )
        self._B = np.zeros((n_features, self._n_components), dtype=X_train.dtype)

        if self.max_iter is not None:
            # Attributes to monitor the convergence
            self._ewa_cost = None
            self._ewa_cost_min = None
            self._no_improvement = 0

            batches = gen_batches(n_samples, self._batch_size)
            batches = itertools.cycle(batches)
            n_steps_per_iter = int(np.ceil(n_samples / self._batch_size))
            n_steps = self.max_iter * n_steps_per_iter

            i = -1  # to allow max_iter = 0

            for i, batch in zip(range(n_steps), batches):
                X_batch = X_train[batch]

                batch_cost = self._minibatch_step(
                    X_batch, dictionary, self._random_state, i
                )

                if self._check_convergence(
                    X_batch, batch_cost, dictionary, old_dict, n_samples, i, n_steps
                ):
                    break

                # XXX callback param added for backward compat in #18975 but a common
                # unified callback API should be preferred
                if self.callback is not None:
                    self.callback(locals())

                old_dict[:] = dictionary

            self.n_steps_ = i + 1
            self.n_iter_ = np.ceil(self.n_steps_ / n_steps_per_iter)
        else:
            # TODO remove this branch in 1.4
            n_iter = 1000 if self.n_iter == "deprecated" else self.n_iter

            batches = gen_batches(n_samples, self._batch_size)
            batches = itertools.cycle(batches)

            for i, batch in zip(range(n_iter), batches):
                self._minibatch_step(X_train[batch], dictionary, self._random_state, i)

                trigger_verbose = self.verbose and i % ceil(100.0 / self.verbose) == 0
                if self.verbose > 10 or trigger_verbose:
                    print(f"{i} batches processed.")

                if self.callback is not None:
                    self.callback(locals())

            self.n_steps_ = n_iter
            self.n_iter_ = np.ceil(n_iter / int(np.ceil(n_samples / self._batch_size)))

        self.components_ = dictionary

        return self
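
    # A minimal fit sketch, assuming hypothetical random data and using
    # `max_iter` (the `n_iter` parameter is deprecated and scheduled for
    # removal in 1.4):
    #
    #     import numpy as np
    #     from sklearn.decomposition import MiniBatchDictionaryLearning
    #     X = np.random.RandomState(0).randn(100, 12)
    #     mbdl = MiniBatchDictionaryLearning(n_components=6, batch_size=16,
    #                                        max_iter=10, random_state=0)
    #     mbdl.fit(X)
    #     # mbdl.n_steps_ mini-batches processed over mbdl.n_iter_ epochs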

    @_fit_context(prefer_skip_nested_validation=True)
    def partial_fit(self, X, y=None):
        """Update the model using the data in X as a mini-batch.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Return the instance itself.
        """
        has_components = hasattr(self, "components_")

        X = self._validate_data(
            X, dtype=[np.float64, np.float32], order="C", reset=not has_components
        )

        if not has_components:
            # This instance has not been fitted yet (fit or partial_fit)
            self._check_params(X)
            self._random_state = check_random_state(self.random_state)

            dictionary = self._initialize_dict(X, self._random_state)

            self.n_steps_ = 0

            self._A = np.zeros((self._n_components, self._n_components), dtype=X.dtype)
            self._B = np.zeros((X.shape[1], self._n_components), dtype=X.dtype)
        else:
            dictionary = self.components_

        self._minibatch_step(X, dictionary, self._random_state, self.n_steps_)

        self.components_ = dictionary
        self.n_steps_ += 1

        return self
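
    # A minimal streaming sketch, assuming hypothetical chunked data: each
    # call to `partial_fit` treats the given chunk as one mini-batch, so a
    # stream can be consumed without holding it all in memory:
    #
    #     import numpy as np
    #     from sklearn.decomposition import MiniBatchDictionaryLearning
    #     rng = np.random.RandomState(0)
    #     mbdl = MiniBatchDictionaryLearning(n_components=5, random_state=0)
    #     for _ in range(10):            # e.g. chunks read from disk
    #         mbdl.partial_fit(rng.randn(32, 12))
    #     # mbdl.components_ has shape (5, 12)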

    @property
    def _n_features_out(self):
        """Number of transformed output features."""
        return self.components_.shape[0]

    def _more_tags(self):
        return {
            "preserves_dtype": [np.float64, np.float32],
        }