# test_dict_learning.py
import itertools
import warnings
from functools import partial

import numpy as np
import pytest

import sklearn
from sklearn.base import clone
from sklearn.decomposition import (
    DictionaryLearning,
    MiniBatchDictionaryLearning,
    SparseCoder,
    dict_learning,
    dict_learning_online,
    sparse_encode,
)
from sklearn.decomposition._dict_learning import _update_dict
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils import check_array
from sklearn.utils._testing import (
    TempMemmap,
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
    ignore_warnings,
)
from sklearn.utils.estimator_checks import (
    check_transformer_data_not_an_array,
    check_transformer_general,
    check_transformers_unfitted,
)
from sklearn.utils.parallel import Parallel

rng_global = np.random.RandomState(0)
n_samples, n_features = 10, 8
X = rng_global.randn(n_samples, n_features)
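
# `X` is the small dense (n_samples=10, n_features=8) Gaussian matrix shared
# as the default fit/transform input by most of the tests below.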


def test_sparse_encode_shapes_omp():
    rng = np.random.RandomState(0)
    algorithms = ["omp", "lasso_lars", "lasso_cd", "lars", "threshold"]
    for n_components, n_samples in itertools.product([1, 5], [1, 9]):
        X_ = rng.randn(n_samples, n_features)
        dictionary = rng.randn(n_components, n_features)
        for algorithm, n_jobs in itertools.product(algorithms, [1, 2]):
            code = sparse_encode(X_, dictionary, algorithm=algorithm, n_jobs=n_jobs)
            assert code.shape == (n_samples, n_components)


def test_dict_learning_shapes():
    n_components = 5
    dico = DictionaryLearning(n_components, random_state=0).fit(X)
    assert dico.components_.shape == (n_components, n_features)

    n_components = 1
    dico = DictionaryLearning(n_components, random_state=0).fit(X)
    assert dico.components_.shape == (n_components, n_features)
    assert dico.transform(X).shape == (X.shape[0], n_components)


def test_dict_learning_overcomplete():
    n_components = 12
    dico = DictionaryLearning(n_components, random_state=0).fit(X)
    assert dico.components_.shape == (n_components, n_features)
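

# `test_max_iter` builds an overcomplete dictionary of Ricker (Mexican hat)
# wavelets and encodes a step signal with it: with transform_max_iter=1 the
# lasso_cd solver must emit a ConvergenceWarning, while a generous iteration
# budget must converge without any warning.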
def test_max_iter():
    def ricker_function(resolution, center, width):
        """Discrete sub-sampled Ricker (Mexican hat) wavelet"""
        x = np.linspace(0, resolution - 1, resolution)
        x = (
            (2 / (np.sqrt(3 * width) * np.pi**0.25))
            * (1 - (x - center) ** 2 / width**2)
            * np.exp(-((x - center) ** 2) / (2 * width**2))
        )
        return x

    def ricker_matrix(width, resolution, n_components):
        """Dictionary of Ricker (Mexican hat) wavelets"""
        centers = np.linspace(0, resolution - 1, n_components)
        D = np.empty((n_components, resolution))
        for i, center in enumerate(centers):
            D[i] = ricker_function(resolution, center, width)
        D /= np.sqrt(np.sum(D**2, axis=1))[:, np.newaxis]
        return D

    transform_algorithm = "lasso_cd"
    resolution = 1024
    subsampling = 3  # subsampling factor
    n_components = resolution // subsampling

    # Compute a wavelet dictionary
    D_multi = np.r_[
        tuple(
            ricker_matrix(
                width=w, resolution=resolution, n_components=n_components // 5
            )
            for w in (10, 50, 100, 500, 1000)
        )
    ]

    X = np.linspace(0, resolution - 1, resolution)
    first_quarter = X < resolution / 4
    X[first_quarter] = 3.0
    X[np.logical_not(first_quarter)] = -1.0
    X = X.reshape(1, -1)

    # check that the underlying model fails to converge
    with pytest.warns(ConvergenceWarning):
        model = SparseCoder(
            D_multi, transform_algorithm=transform_algorithm, transform_max_iter=1
        )
        model.fit_transform(X)

    # check that the underlying model converges w/o warnings
    with warnings.catch_warnings():
        warnings.simplefilter("error", ConvergenceWarning)
        model = SparseCoder(
            D_multi, transform_algorithm=transform_algorithm, transform_max_iter=2000
        )
        model.fit_transform(X)


def test_dict_learning_lars_positive_parameter():
    n_components = 5
    alpha = 1
    err_msg = "Positive constraint not supported for 'lars' coding method."
    with pytest.raises(ValueError, match=err_msg):
        dict_learning(X, n_components, alpha=alpha, positive_code=True)
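

# The positivity tests below check the contract of the `positive_code` and
# `positive_dict` flags: when a flag is set, the corresponding output must be
# elementwise non-negative; when it is not set, at least one negative entry is
# expected on this data.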
@pytest.mark.parametrize(
    "transform_algorithm",
    [
        "lasso_lars",
        "lasso_cd",
        "threshold",
    ],
)
@pytest.mark.parametrize("positive_code", [False, True])
@pytest.mark.parametrize("positive_dict", [False, True])
def test_dict_learning_positivity(transform_algorithm, positive_code, positive_dict):
    n_components = 5
    dico = DictionaryLearning(
        n_components,
        transform_algorithm=transform_algorithm,
        random_state=0,
        positive_code=positive_code,
        positive_dict=positive_dict,
        fit_algorithm="cd",
    ).fit(X)

    code = dico.transform(X)
    if positive_dict:
        assert (dico.components_ >= 0).all()
    else:
        assert (dico.components_ < 0).any()
    if positive_code:
        assert (code >= 0).all()
    else:
        assert (code < 0).any()


@pytest.mark.parametrize("positive_dict", [False, True])
def test_dict_learning_lars_dict_positivity(positive_dict):
    n_components = 5
    dico = DictionaryLearning(
        n_components,
        transform_algorithm="lars",
        random_state=0,
        positive_dict=positive_dict,
        fit_algorithm="cd",
    ).fit(X)

    if positive_dict:
        assert (dico.components_ >= 0).all()
    else:
        assert (dico.components_ < 0).any()


def test_dict_learning_lars_code_positivity():
    n_components = 5
    dico = DictionaryLearning(
        n_components,
        transform_algorithm="lars",
        random_state=0,
        positive_code=True,
        fit_algorithm="cd",
    ).fit(X)

    err_msg = "Positive constraint not supported for '{}' coding method."
    err_msg = err_msg.format("lars")
    with pytest.raises(ValueError, match=err_msg):
        dico.transform(X)
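

# The reconstruction tests check the core identity of sparse coding: with an
# overcomplete dictionary and a loose sparsity penalty, the decoded signal
# `code @ dico.components_` should be approximately equal to `X`.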
def test_dict_learning_reconstruction():
    n_components = 12
    dico = DictionaryLearning(
        n_components, transform_algorithm="omp", transform_alpha=0.001, random_state=0
    )
    code = dico.fit(X).transform(X)
    assert_array_almost_equal(np.dot(code, dico.components_), X)

    dico.set_params(transform_algorithm="lasso_lars")
    code = dico.transform(X)
    assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2)

    # used to test lars here too, but there's no guarantee the number of
    # nonzero atoms is right.


def test_dict_learning_reconstruction_parallel():
    # regression test that parallel reconstruction works with n_jobs>1
    n_components = 12
    dico = DictionaryLearning(
        n_components,
        transform_algorithm="omp",
        transform_alpha=0.001,
        random_state=0,
        n_jobs=4,
    )
    code = dico.fit(X).transform(X)
    assert_array_almost_equal(np.dot(code, dico.components_), X)

    dico.set_params(transform_algorithm="lasso_lars")
    code = dico.transform(X)
    assert_array_almost_equal(np.dot(code, dico.components_), X, decimal=2)


def test_dict_learning_lassocd_readonly_data():
    n_components = 12
    with TempMemmap(X) as X_read_only:
        dico = DictionaryLearning(
            n_components,
            transform_algorithm="lasso_cd",
            transform_alpha=0.001,
            random_state=0,
            n_jobs=4,
        )
        with ignore_warnings(category=ConvergenceWarning):
            code = dico.fit(X_read_only).transform(X_read_only)
        assert_array_almost_equal(
            np.dot(code, dico.components_), X_read_only, decimal=2
        )


def test_dict_learning_nonzero_coefs():
    n_components = 4
    dico = DictionaryLearning(
        n_components,
        transform_algorithm="lars",
        transform_n_nonzero_coefs=3,
        random_state=0,
    )
    code = dico.fit(X).transform(X[np.newaxis, 1])
    assert len(np.flatnonzero(code)) == 3

    dico.set_params(transform_algorithm="omp")
    code = dico.transform(X[np.newaxis, 1])
    assert len(np.flatnonzero(code)) == 3
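

# With `split_sign=True` the transform concatenates the positive and negative
# parts of the code (e.g. a row [1, -2] becomes [1, 0, 0, 2]), so subtracting
# the second half from the first recovers the unsplit code.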
def test_dict_learning_split():
    n_components = 5
    dico = DictionaryLearning(
        n_components, transform_algorithm="threshold", random_state=0
    )
    code = dico.fit(X).transform(X)
    dico.split_sign = True
    split_code = dico.transform(X)

    assert_array_almost_equal(
        split_code[:, :n_components] - split_code[:, n_components:], code
    )


def test_dict_learning_online_shapes():
    rng = np.random.RandomState(0)

    n_components = 8
    code, dictionary = dict_learning_online(
        X,
        n_components=n_components,
        batch_size=4,
        max_iter=10,
        method="cd",
        random_state=rng,
        return_code=True,
    )
    assert code.shape == (n_samples, n_components)
    assert dictionary.shape == (n_components, n_features)
    assert np.dot(code, dictionary).shape == X.shape

    dictionary = dict_learning_online(
        X,
        n_components=n_components,
        batch_size=4,
        max_iter=10,
        method="cd",
        random_state=rng,
        return_code=False,
    )
    assert dictionary.shape == (n_components, n_features)


def test_dict_learning_online_lars_positive_parameter():
    err_msg = "Positive constraint not supported for 'lars' coding method."
    with pytest.raises(ValueError, match=err_msg):
        dict_learning_online(X, batch_size=4, max_iter=10, positive_code=True)


@pytest.mark.parametrize(
    "transform_algorithm",
    [
        "lasso_lars",
        "lasso_cd",
        "threshold",
    ],
)
@pytest.mark.parametrize("positive_code", [False, True])
@pytest.mark.parametrize("positive_dict", [False, True])
def test_minibatch_dictionary_learning_positivity(
    transform_algorithm, positive_code, positive_dict
):
    n_components = 8
    dico = MiniBatchDictionaryLearning(
        n_components,
        batch_size=4,
        max_iter=10,
        transform_algorithm=transform_algorithm,
        random_state=0,
        positive_code=positive_code,
        positive_dict=positive_dict,
        fit_algorithm="cd",
    ).fit(X)

    code = dico.transform(X)
    if positive_dict:
        assert (dico.components_ >= 0).all()
    else:
        assert (dico.components_ < 0).any()
    if positive_code:
        assert (code >= 0).all()
    else:
        assert (code < 0).any()


@pytest.mark.parametrize("positive_dict", [False, True])
def test_minibatch_dictionary_learning_lars(positive_dict):
    n_components = 8
    dico = MiniBatchDictionaryLearning(
        n_components,
        batch_size=4,
        max_iter=10,
        transform_algorithm="lars",
        random_state=0,
        positive_dict=positive_dict,
        fit_algorithm="cd",
    ).fit(X)

    if positive_dict:
        assert (dico.components_ >= 0).all()
    else:
        assert (dico.components_ < 0).any()


@pytest.mark.parametrize("positive_code", [False, True])
@pytest.mark.parametrize("positive_dict", [False, True])
def test_dict_learning_online_positivity(positive_code, positive_dict):
    rng = np.random.RandomState(0)
    n_components = 8

    code, dictionary = dict_learning_online(
        X,
        n_components=n_components,
        batch_size=4,
        method="cd",
        alpha=1,
        random_state=rng,
        positive_dict=positive_dict,
        positive_code=positive_code,
    )
    if positive_dict:
        assert (dictionary >= 0).all()
    else:
        assert (dictionary < 0).any()
    if positive_code:
        assert (code >= 0).all()
    else:
        assert (code < 0).any()
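

# The verbosity test only exercises the `verbose` code paths (convergence
# monitoring, higher verbosity levels, and the function API) while stdout is
# redirected; the shape assert at the end is its only functional check.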
def test_dict_learning_online_verbosity():
    # test verbosity for better coverage
    n_components = 5
    import sys
    from io import StringIO

    old_stdout = sys.stdout
    try:
        sys.stdout = StringIO()

        # convergence monitoring verbosity
        dico = MiniBatchDictionaryLearning(
            n_components, batch_size=4, max_iter=5, verbose=1, tol=0.1, random_state=0
        )
        dico.fit(X)
        dico = MiniBatchDictionaryLearning(
            n_components,
            batch_size=4,
            max_iter=5,
            verbose=1,
            max_no_improvement=2,
            random_state=0,
        )
        dico.fit(X)
        # higher verbosity level
        dico = MiniBatchDictionaryLearning(
            n_components, batch_size=4, max_iter=5, verbose=2, random_state=0
        )
        dico.fit(X)

        # function API verbosity
        dict_learning_online(
            X,
            n_components=n_components,
            batch_size=4,
            alpha=1,
            verbose=1,
            random_state=0,
        )
        dict_learning_online(
            X,
            n_components=n_components,
            batch_size=4,
            alpha=1,
            verbose=2,
            random_state=0,
        )
    finally:
        sys.stdout = old_stdout

    assert dico.components_.shape == (n_components, n_features)


def test_dict_learning_online_estimator_shapes():
    n_components = 5
    dico = MiniBatchDictionaryLearning(
        n_components, batch_size=4, max_iter=5, random_state=0
    )
    dico.fit(X)
    assert dico.components_.shape == (n_components, n_features)


def test_dict_learning_online_overcomplete():
    n_components = 12
    dico = MiniBatchDictionaryLearning(
        n_components, batch_size=4, max_iter=5, random_state=0
    ).fit(X)
    assert dico.components_.shape == (n_components, n_features)


def test_dict_learning_online_initialization():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)
    dico = MiniBatchDictionaryLearning(
        n_components, batch_size=4, max_iter=0, dict_init=V, random_state=0
    ).fit(X)
    assert_array_equal(dico.components_, V)


def test_dict_learning_online_readonly_initialization():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)
    V.setflags(write=False)
    MiniBatchDictionaryLearning(
        n_components,
        batch_size=4,
        max_iter=1,
        dict_init=V,
        random_state=0,
        shuffle=False,
    ).fit(X)
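

# Fitting in one call and streaming the same data through `partial_fit` should
# reach nearly the same dictionary when both estimators start from the same
# `dict_init` and see the samples in the same order.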
def test_dict_learning_online_partial_fit():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    dict1 = MiniBatchDictionaryLearning(
        n_components,
        max_iter=10,
        batch_size=1,
        alpha=1,
        shuffle=False,
        dict_init=V,
        max_no_improvement=None,
        tol=0.0,
        random_state=0,
    ).fit(X)
    dict2 = MiniBatchDictionaryLearning(
        n_components, alpha=1, dict_init=V, random_state=0
    )
    for i in range(10):
        for sample in X:
            dict2.partial_fit(sample[np.newaxis, :])

    assert not np.all(sparse_encode(X, dict1.components_, alpha=1) == 0)
    assert_array_almost_equal(dict1.components_, dict2.components_, decimal=2)
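
    # With 10 samples, batch_size=1 and max_iter=10, the one-shot fit performs
    # 10 * 10 = 100 steps; ten epochs of per-sample partial_fit should match.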
    # partial_fit should ignore max_iter (#17433)
    assert dict1.n_steps_ == dict2.n_steps_ == 100


def test_sparse_encode_shapes():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    for algo in ("lasso_lars", "lasso_cd", "lars", "omp", "threshold"):
        code = sparse_encode(X, V, algorithm=algo)
        assert code.shape == (n_samples, n_components)


@pytest.mark.parametrize("algo", ["lasso_lars", "lasso_cd", "threshold"])
@pytest.mark.parametrize("positive", [False, True])
def test_sparse_encode_positivity(algo, positive):
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    code = sparse_encode(X, V, algorithm=algo, positive=positive)
    if positive:
        assert (code >= 0).all()
    else:
        assert (code < 0).any()


@pytest.mark.parametrize("algo", ["lars", "omp"])
def test_sparse_encode_unavailable_positivity(algo):
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    err_msg = "Positive constraint not supported for '{}' coding method."
    err_msg = err_msg.format(algo)
    with pytest.raises(ValueError, match=err_msg):
        sparse_encode(X, V, algorithm=algo, positive=True)


def test_sparse_encode_input():
    n_components = 100
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    Xf = check_array(X, order="F")
    for algo in ("lasso_lars", "lasso_cd", "lars", "omp", "threshold"):
        a = sparse_encode(X, V, algorithm=algo)
        b = sparse_encode(Xf, V, algorithm=algo)
        assert_array_almost_equal(a, b)


def test_sparse_encode_error():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    code = sparse_encode(X, V, alpha=0.001)
    assert not np.all(code == 0)
    assert np.sqrt(np.sum((np.dot(code, V) - X) ** 2)) < 0.1


def test_sparse_encode_error_default_sparsity():
    rng = np.random.RandomState(0)
    X = rng.randn(100, 64)
    D = rng.randn(2, 64)
    code = ignore_warnings(sparse_encode)(X, D, algorithm="omp", n_nonzero_coefs=None)
    assert code.shape == (100, 2)


def test_sparse_coder_estimator():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    coder = SparseCoder(
        dictionary=V, transform_algorithm="lasso_lars", transform_alpha=0.001
    ).transform(X)
    assert not np.all(coder == 0)
    assert np.sqrt(np.sum((np.dot(coder, V) - X) ** 2)) < 0.1


def test_sparse_coder_estimator_clone():
    n_components = 12
    rng = np.random.RandomState(0)
    V = rng.randn(n_components, n_features)  # random init
    V /= np.sum(V**2, axis=1)[:, np.newaxis]
    coder = SparseCoder(
        dictionary=V, transform_algorithm="lasso_lars", transform_alpha=0.001
    )
    cloned = clone(coder)
    assert id(cloned) != id(coder)
    np.testing.assert_allclose(cloned.dictionary, coder.dictionary)
    assert id(cloned.dictionary) != id(coder.dictionary)
    assert cloned.n_components_ == coder.n_components_
    assert cloned.n_features_in_ == coder.n_features_in_
    data = np.random.rand(n_samples, n_features).astype(np.float32)
    np.testing.assert_allclose(cloned.transform(data), coder.transform(data))


def test_sparse_coder_parallel_mmap():
    # Non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/5956
    # Test that SparseCoder does not error when passing read-only
    # arrays to child processes.
    rng = np.random.RandomState(777)
    n_components, n_features = 40, 64
    init_dict = rng.rand(n_components, n_features)
    # Ensure that `data` is >2MB: joblib memory-maps arrays
    # if they are larger than 1MB. The 4 accounts for the size
    # of the float32 data type.
    n_samples = int(2e6) // (4 * n_features)
    data = np.random.rand(n_samples, n_features).astype(np.float32)

    sc = SparseCoder(init_dict, transform_algorithm="omp", n_jobs=2)
    sc.fit_transform(data)


def test_sparse_coder_common_transformer():
    rng = np.random.RandomState(777)
    n_components, n_features = 40, 3
    init_dict = rng.rand(n_components, n_features)

    sc = SparseCoder(init_dict)

    check_transformer_data_not_an_array(sc.__class__.__name__, sc)
    check_transformer_general(sc.__class__.__name__, sc)
    check_transformer_general_memmap = partial(
        check_transformer_general, readonly_memmap=True
    )
    check_transformer_general_memmap(sc.__class__.__name__, sc)
    check_transformers_unfitted(sc.__class__.__name__, sc)


def test_sparse_coder_n_features_in():
    d = np.array([[1, 2, 3], [1, 2, 3]])
    sc = SparseCoder(d)
    assert sc.n_features_in_ == d.shape[1]


def test_minibatch_dict_learning_n_iter_deprecated():
    # check the deprecation warning of n_iter
    # TODO(1.4) remove
    depr_msg = (
        "'n_iter' is deprecated in version 1.1 and will be removed in version 1.4"
    )
    est = MiniBatchDictionaryLearning(
        n_components=2, batch_size=4, n_iter=5, random_state=0
    )
    with pytest.warns(FutureWarning, match=depr_msg):
        est.fit(X)


@pytest.mark.parametrize(
    "arg, val",
    [
        ("iter_offset", 0),
        ("inner_stats", None),
        ("return_inner_stats", False),
        ("return_n_iter", False),
        ("n_iter", 5),
    ],
)
def test_dict_learning_online_deprecated_args(arg, val):
    # check the deprecation warning for the deprecated args of
    # dict_learning_online
    # TODO(1.4) remove
    depr_msg = (
        f"'{arg}' is deprecated in version 1.1 and will be removed in version 1.4."
    )
    with pytest.warns(FutureWarning, match=depr_msg):
        dict_learning_online(
            X, n_components=2, batch_size=4, random_state=0, **{arg: val}
        )


def test_update_dict():
    # Check the dict update in batch mode vs online mode
    # Non-regression test for #4866
    rng = np.random.RandomState(0)

    code = np.array([[0.5, -0.5], [0.1, 0.9]])
    dictionary = np.array([[1.0, 0.0], [0.6, 0.8]])

    X = np.dot(code, dictionary) + rng.randn(2, 2)

    # full batch update
    newd_batch = dictionary.copy()
    _update_dict(newd_batch, X, code)
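
    # In the online mode, A = code.T @ code and B = X.T @ code are the
    # sufficient statistics that the minibatch estimator accumulates across
    # batches (cf. the `_A`/`_B` attributes checked in the dtype tests below);
    # computed here from a single batch, both updates must agree.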
    # online update
    A = np.dot(code.T, code)
    B = np.dot(X.T, code)
    newd_online = dictionary.copy()
    _update_dict(newd_online, X, code, A, B)

    assert_allclose(newd_batch, newd_online)


# TODO(1.4) remove
def test_dict_learning_online_n_iter_deprecated():
    # Check that an error is raised when a deprecated argument is set when max_iter
    # is also set.
    msg = "The following arguments are incompatible with 'max_iter'"
    with pytest.raises(ValueError, match=msg):
        dict_learning_online(X, max_iter=10, return_inner_stats=True)


@pytest.mark.parametrize(
    "algorithm", ("lasso_lars", "lasso_cd", "lars", "threshold", "omp")
)
@pytest.mark.parametrize("data_type", (np.float32, np.float64))
# Note: do not check integer input because `lasso_lars` and `lars` fail with
# `ValueError` in `_lars_path_solver`
def test_sparse_encode_dtype_match(data_type, algorithm):
    n_components = 6
    rng = np.random.RandomState(0)
    dictionary = rng.randn(n_components, n_features)
    code = sparse_encode(
        X.astype(data_type), dictionary.astype(data_type), algorithm=algorithm
    )
    assert code.dtype == data_type


@pytest.mark.parametrize(
    "algorithm", ("lasso_lars", "lasso_cd", "lars", "threshold", "omp")
)
def test_sparse_encode_numerical_consistency(algorithm):
    # verify numerical consistency between np.float32 and np.float64
    rtol = 1e-4
    n_components = 6
    rng = np.random.RandomState(0)
    dictionary = rng.randn(n_components, n_features)
    code_32 = sparse_encode(
        X.astype(np.float32), dictionary.astype(np.float32), algorithm=algorithm
    )
    code_64 = sparse_encode(
        X.astype(np.float64), dictionary.astype(np.float64), algorithm=algorithm
    )
    assert_allclose(code_32, code_64, rtol=rtol)


@pytest.mark.parametrize(
    "transform_algorithm", ("lasso_lars", "lasso_cd", "lars", "threshold", "omp")
)
@pytest.mark.parametrize("data_type", (np.float32, np.float64))
# Note: do not check integer input because `lasso_lars` and `lars` fail with
# `ValueError` in `_lars_path_solver`
def test_sparse_coder_dtype_match(data_type, transform_algorithm):
    # Verify preserving dtype for transform in sparse coder
    n_components = 6
    rng = np.random.RandomState(0)
    dictionary = rng.randn(n_components, n_features)
    coder = SparseCoder(
        dictionary.astype(data_type), transform_algorithm=transform_algorithm
    )
    code = coder.transform(X.astype(data_type))
    assert code.dtype == data_type


@pytest.mark.parametrize("fit_algorithm", ("lars", "cd"))
@pytest.mark.parametrize(
    "transform_algorithm", ("lasso_lars", "lasso_cd", "lars", "threshold", "omp")
)
@pytest.mark.parametrize(
    "data_type, expected_type",
    (
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64),
    ),
)
def test_dictionary_learning_dtype_match(
    data_type,
    expected_type,
    fit_algorithm,
    transform_algorithm,
):
    # Verify preserving dtype for fit and transform in dictionary learning class
    dict_learner = DictionaryLearning(
        n_components=8,
        fit_algorithm=fit_algorithm,
        transform_algorithm=transform_algorithm,
        random_state=0,
    )
    dict_learner.fit(X.astype(data_type))
    assert dict_learner.components_.dtype == expected_type
    assert dict_learner.transform(X.astype(data_type)).dtype == expected_type


@pytest.mark.parametrize("fit_algorithm", ("lars", "cd"))
@pytest.mark.parametrize(
    "transform_algorithm", ("lasso_lars", "lasso_cd", "lars", "threshold", "omp")
)
@pytest.mark.parametrize(
    "data_type, expected_type",
    (
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64),
    ),
)
def test_minibatch_dictionary_learning_dtype_match(
    data_type,
    expected_type,
    fit_algorithm,
    transform_algorithm,
):
    # Verify preserving dtype for fit and transform in minibatch dictionary learning
    dict_learner = MiniBatchDictionaryLearning(
        n_components=8,
        batch_size=10,
        fit_algorithm=fit_algorithm,
        transform_algorithm=transform_algorithm,
        max_iter=100,
        tol=1e-1,
        random_state=0,
    )
    dict_learner.fit(X.astype(data_type))

    assert dict_learner.components_.dtype == expected_type
    assert dict_learner.transform(X.astype(data_type)).dtype == expected_type
    assert dict_learner._A.dtype == expected_type
    assert dict_learner._B.dtype == expected_type


@pytest.mark.parametrize("method", ("lars", "cd"))
@pytest.mark.parametrize(
    "data_type, expected_type",
    (
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64),
    ),
)
def test_dict_learning_dtype_match(data_type, expected_type, method):
    # Verify output matrix dtype
    rng = np.random.RandomState(0)
    n_components = 8
    code, dictionary, _ = dict_learning(
        X.astype(data_type),
        n_components=n_components,
        alpha=1,
        random_state=rng,
        method=method,
    )
    assert code.dtype == expected_type
    assert dictionary.dtype == expected_type


@pytest.mark.parametrize("method", ("lars", "cd"))
def test_dict_learning_numerical_consistency(method):
    # verify numerical consistency between np.float32 and np.float64
    rtol = 1e-6
    n_components = 4
    alpha = 2

    U_64, V_64, _ = dict_learning(
        X.astype(np.float64),
        n_components=n_components,
        alpha=alpha,
        random_state=0,
        method=method,
    )
    U_32, V_32, _ = dict_learning(
        X.astype(np.float32),
        n_components=n_components,
        alpha=alpha,
        random_state=0,
        method=method,
    )

    # The optimal solution (U*, V*) is not unique: if (U*, V*) is optimal,
    # so is (-U*, -V*), and any column permutation of U* with the matching
    # row permutation of V* is also optimal, since the product UV is
    # unchanged. So UV, ||U||_1,1 and sum(||V_k||_2^2) are compared here
    # instead of comparing U and V directly.
    assert_allclose(np.matmul(U_64, V_64), np.matmul(U_32, V_32), rtol=rtol)
    assert_allclose(np.sum(np.abs(U_64)), np.sum(np.abs(U_32)), rtol=rtol)
    assert_allclose(np.sum(V_64**2), np.sum(V_32**2), rtol=rtol)
    # verify that the obtained solution is not degenerate
    assert np.mean(U_64 != 0.0) > 0.05
    assert np.count_nonzero(U_64 != 0.0) == np.count_nonzero(U_32 != 0.0)


@pytest.mark.parametrize("method", ("lars", "cd"))
@pytest.mark.parametrize(
    "data_type, expected_type",
    (
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.int32, np.float64),
        (np.int64, np.float64),
    ),
)
def test_dict_learning_online_dtype_match(data_type, expected_type, method):
    # Verify output matrix dtype
    rng = np.random.RandomState(0)
    n_components = 8
    code, dictionary = dict_learning_online(
        X.astype(data_type),
        n_components=n_components,
        alpha=1,
        batch_size=10,
        random_state=rng,
        method=method,
    )
    assert code.dtype == expected_type
    assert dictionary.dtype == expected_type


@pytest.mark.parametrize("method", ("lars", "cd"))
def test_dict_learning_online_numerical_consistency(method):
    # verify numerical consistency between np.float32 and np.float64
    rtol = 1e-4
    n_components = 4
    alpha = 1

    U_64, V_64 = dict_learning_online(
        X.astype(np.float64),
        n_components=n_components,
        alpha=alpha,
        batch_size=10,
        random_state=0,
        method=method,
    )
    U_32, V_32 = dict_learning_online(
        X.astype(np.float32),
        n_components=n_components,
        alpha=alpha,
        batch_size=10,
        random_state=0,
        method=method,
    )

    # The optimal solution (U*, V*) is not unique: if (U*, V*) is optimal,
    # so is (-U*, -V*), and any column permutation of U* with the matching
    # row permutation of V* is also optimal, since the product UV is
    # unchanged. So UV, ||U||_1,1 and sum(||V_k||_2) are compared here
    # instead of comparing U and V directly.
    assert_allclose(np.matmul(U_64, V_64), np.matmul(U_32, V_32), rtol=rtol)
    assert_allclose(np.sum(np.abs(U_64)), np.sum(np.abs(U_32)), rtol=rtol)
    assert_allclose(np.sum(V_64**2), np.sum(V_32**2), rtol=rtol)
    # verify that the obtained solution is not degenerate
    assert np.mean(U_64 != 0.0) > 0.05
    assert np.count_nonzero(U_64 != 0.0) == np.count_nonzero(U_32 != 0.0)


@pytest.mark.parametrize(
    "estimator",
    [
        SparseCoder(X.T),
        DictionaryLearning(),
        MiniBatchDictionaryLearning(batch_size=4, max_iter=10),
    ],
    ids=lambda x: x.__class__.__name__,
)
def test_get_feature_names_out(estimator):
    """Check feature names for dict learning estimators."""
    estimator.fit(X)
    n_components = X.shape[1]

    feature_names_out = estimator.get_feature_names_out()
    estimator_name = estimator.__class__.__name__.lower()
    assert_array_equal(
        feature_names_out,
        [f"{estimator_name}{i}" for i in range(n_components)],
    )
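

# joblib's Parallel memory-maps large input arrays (above `max_nbytes`, 1MB by
# default); monkeypatching it down to 100 bytes forces that memmapping even for
# the small `X_train` below, so the coordinate-descent path is exercised on
# read-only memmapped data.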
def test_cd_work_on_joblib_memmapped_data(monkeypatch):
    monkeypatch.setattr(
        sklearn.decomposition._dict_learning,
        "Parallel",
        partial(Parallel, max_nbytes=100),
    )

    rng = np.random.RandomState(0)
    X_train = rng.randn(10, 10)

    dict_learner = DictionaryLearning(
        n_components=5,
        random_state=0,
        n_jobs=2,
        fit_algorithm="cd",
        max_iter=50,
        verbose=True,
    )

    # This must run and complete without error.
    dict_learner.fit(X_train)


# TODO(1.4) remove
def test_minibatch_dictionary_learning_warns_and_ignore_n_iter():
    """Check that a deprecation warning is always raised when `n_iter` is set,
    even though it is ignored when `max_iter` is also set.
    """
    warn_msg = "'n_iter' is deprecated in version 1.1"
    with pytest.warns(FutureWarning, match=warn_msg):
        model = MiniBatchDictionaryLearning(batch_size=256, n_iter=2, max_iter=2).fit(X)

    assert model.n_iter_ == 2