test_omp.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
# Author: Vlad Niculae
# License: BSD 3 clause
  3. import warnings
  4. import numpy as np
  5. import pytest
  6. from sklearn.datasets import make_sparse_coded_signal
  7. from sklearn.linear_model import (
  8. LinearRegression,
  9. OrthogonalMatchingPursuit,
  10. OrthogonalMatchingPursuitCV,
  11. orthogonal_mp,
  12. orthogonal_mp_gram,
  13. )
  14. from sklearn.utils import check_random_state
  15. from sklearn.utils._testing import (
  16. assert_allclose,
  17. assert_array_almost_equal,
  18. assert_array_equal,
  19. ignore_warnings,
  20. )
  21. n_samples, n_features, n_nonzero_coefs, n_targets = 25, 35, 5, 3
  22. y, X, gamma = make_sparse_coded_signal(
  23. n_samples=n_targets,
  24. n_components=n_features,
  25. n_features=n_samples,
  26. n_nonzero_coefs=n_nonzero_coefs,
  27. random_state=0,
  28. )
  29. y, X, gamma = y.T, X.T, gamma.T
  30. # Make X not of norm 1 for testing
  31. X *= 10
  32. y *= 10
  33. G, Xy = np.dot(X.T, X), np.dot(X.T, y)
  34. # this makes X (n_samples, n_features)
  35. # and y (n_samples, 3)
  36. # TODO(1.4): remove
  37. @pytest.mark.parametrize(
  38. "OmpModel", [OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV]
  39. )
  40. @pytest.mark.parametrize(
  41. "normalize, n_warnings", [(True, 1), (False, 1), ("deprecated", 0)]
  42. )
  43. def test_assure_warning_when_normalize(OmpModel, normalize, n_warnings):
  44. # check that we issue a FutureWarning when normalize was set
  45. rng = check_random_state(0)
  46. n_samples = 200
  47. n_features = 2
  48. X = rng.randn(n_samples, n_features)
  49. X[X < 0.1] = 0.0
  50. y = rng.rand(n_samples)
  51. model = OmpModel(normalize=normalize)
  52. with warnings.catch_warnings(record=True) as rec:
  53. warnings.simplefilter("always", FutureWarning)
  54. model.fit(X, y)
  55. assert len([w.message for w in rec]) == n_warnings
  56. def test_correct_shapes():
  57. assert orthogonal_mp(X, y[:, 0], n_nonzero_coefs=5).shape == (n_features,)
  58. assert orthogonal_mp(X, y, n_nonzero_coefs=5).shape == (n_features, 3)
  59. def test_correct_shapes_gram():
  60. assert orthogonal_mp_gram(G, Xy[:, 0], n_nonzero_coefs=5).shape == (n_features,)
  61. assert orthogonal_mp_gram(G, Xy, n_nonzero_coefs=5).shape == (n_features, 3)
  62. def test_n_nonzero_coefs():
  63. assert np.count_nonzero(orthogonal_mp(X, y[:, 0], n_nonzero_coefs=5)) <= 5
  64. assert (
  65. np.count_nonzero(orthogonal_mp(X, y[:, 0], n_nonzero_coefs=5, precompute=True))
  66. <= 5
  67. )
  68. def test_tol():
  69. tol = 0.5
  70. gamma = orthogonal_mp(X, y[:, 0], tol=tol)
  71. gamma_gram = orthogonal_mp(X, y[:, 0], tol=tol, precompute=True)
  72. assert np.sum((y[:, 0] - np.dot(X, gamma)) ** 2) <= tol
  73. assert np.sum((y[:, 0] - np.dot(X, gamma_gram)) ** 2) <= tol
  74. def test_with_without_gram():
  75. assert_array_almost_equal(
  76. orthogonal_mp(X, y, n_nonzero_coefs=5),
  77. orthogonal_mp(X, y, n_nonzero_coefs=5, precompute=True),
  78. )
  79. def test_with_without_gram_tol():
  80. assert_array_almost_equal(
  81. orthogonal_mp(X, y, tol=1.0), orthogonal_mp(X, y, tol=1.0, precompute=True)
  82. )
  83. def test_unreachable_accuracy():
  84. assert_array_almost_equal(
  85. orthogonal_mp(X, y, tol=0), orthogonal_mp(X, y, n_nonzero_coefs=n_features)
  86. )
  87. warning_message = (
  88. "Orthogonal matching pursuit ended prematurely "
  89. "due to linear dependence in the dictionary. "
  90. "The requested precision might not have been met."
  91. )
  92. with pytest.warns(RuntimeWarning, match=warning_message):
  93. assert_array_almost_equal(
  94. orthogonal_mp(X, y, tol=0, precompute=True),
  95. orthogonal_mp(X, y, precompute=True, n_nonzero_coefs=n_features),
  96. )
  97. @pytest.mark.parametrize("positional_params", [(X, y), (G, Xy)])
  98. @pytest.mark.parametrize(
  99. "keyword_params",
  100. [{"n_nonzero_coefs": n_features + 1}],
  101. )
  102. def test_bad_input(positional_params, keyword_params):
  103. with pytest.raises(ValueError):
  104. orthogonal_mp(*positional_params, **keyword_params)
  105. def test_perfect_signal_recovery():
  106. (idx,) = gamma[:, 0].nonzero()
  107. gamma_rec = orthogonal_mp(X, y[:, 0], n_nonzero_coefs=5)
  108. gamma_gram = orthogonal_mp_gram(G, Xy[:, 0], n_nonzero_coefs=5)
  109. assert_array_equal(idx, np.flatnonzero(gamma_rec))
  110. assert_array_equal(idx, np.flatnonzero(gamma_gram))
  111. assert_array_almost_equal(gamma[:, 0], gamma_rec, decimal=2)
  112. assert_array_almost_equal(gamma[:, 0], gamma_gram, decimal=2)
  113. def test_orthogonal_mp_gram_readonly():
  114. # Non-regression test for:
  115. # https://github.com/scikit-learn/scikit-learn/issues/5956
  116. (idx,) = gamma[:, 0].nonzero()
  117. G_readonly = G.copy()
  118. G_readonly.setflags(write=False)
  119. Xy_readonly = Xy.copy()
  120. Xy_readonly.setflags(write=False)
  121. gamma_gram = orthogonal_mp_gram(
  122. G_readonly, Xy_readonly[:, 0], n_nonzero_coefs=5, copy_Gram=False, copy_Xy=False
  123. )
  124. assert_array_equal(idx, np.flatnonzero(gamma_gram))
  125. assert_array_almost_equal(gamma[:, 0], gamma_gram, decimal=2)
  126. # TODO(1.4): 'normalize' to be removed
  127. @pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
  128. def test_estimator():
  129. omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_nonzero_coefs)
  130. omp.fit(X, y[:, 0])
  131. assert omp.coef_.shape == (n_features,)
  132. assert omp.intercept_.shape == ()
  133. assert np.count_nonzero(omp.coef_) <= n_nonzero_coefs
  134. omp.fit(X, y)
  135. assert omp.coef_.shape == (n_targets, n_features)
  136. assert omp.intercept_.shape == (n_targets,)
  137. assert np.count_nonzero(omp.coef_) <= n_targets * n_nonzero_coefs
  138. coef_normalized = omp.coef_[0].copy()
  139. omp.set_params(fit_intercept=True)
  140. omp.fit(X, y[:, 0])
  141. assert_array_almost_equal(coef_normalized, omp.coef_)
  142. omp.set_params(fit_intercept=False)
  143. omp.fit(X, y[:, 0])
  144. assert np.count_nonzero(omp.coef_) <= n_nonzero_coefs
  145. assert omp.coef_.shape == (n_features,)
  146. assert omp.intercept_ == 0
  147. omp.fit(X, y)
  148. assert omp.coef_.shape == (n_targets, n_features)
  149. assert omp.intercept_ == 0
  150. assert np.count_nonzero(omp.coef_) <= n_targets * n_nonzero_coefs
  151. def test_identical_regressors():
  152. newX = X.copy()
  153. newX[:, 1] = newX[:, 0]
  154. gamma = np.zeros(n_features)
  155. gamma[0] = gamma[1] = 1.0
  156. newy = np.dot(newX, gamma)
  157. warning_message = (
  158. "Orthogonal matching pursuit ended prematurely "
  159. "due to linear dependence in the dictionary. "
  160. "The requested precision might not have been met."
  161. )
  162. with pytest.warns(RuntimeWarning, match=warning_message):
  163. orthogonal_mp(newX, newy, n_nonzero_coefs=2)
  164. def test_swapped_regressors():
  165. gamma = np.zeros(n_features)
  166. # X[:, 21] should be selected first, then X[:, 0] selected second,
  167. # which will take X[:, 21]'s place in case the algorithm does
  168. # column swapping for optimization (which is the case at the moment)
  169. gamma[21] = 1.0
  170. gamma[0] = 0.5
  171. new_y = np.dot(X, gamma)
  172. new_Xy = np.dot(X.T, new_y)
  173. gamma_hat = orthogonal_mp(X, new_y, n_nonzero_coefs=2)
  174. gamma_hat_gram = orthogonal_mp_gram(G, new_Xy, n_nonzero_coefs=2)
  175. assert_array_equal(np.flatnonzero(gamma_hat), [0, 21])
  176. assert_array_equal(np.flatnonzero(gamma_hat_gram), [0, 21])
  177. def test_no_atoms():
  178. y_empty = np.zeros_like(y)
  179. Xy_empty = np.dot(X.T, y_empty)
  180. gamma_empty = ignore_warnings(orthogonal_mp)(X, y_empty, n_nonzero_coefs=1)
  181. gamma_empty_gram = ignore_warnings(orthogonal_mp)(G, Xy_empty, n_nonzero_coefs=1)
  182. assert np.all(gamma_empty == 0)
  183. assert np.all(gamma_empty_gram == 0)
  184. def test_omp_path():
  185. path = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=True)
  186. last = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=False)
  187. assert path.shape == (n_features, n_targets, 5)
  188. assert_array_almost_equal(path[:, :, -1], last)
  189. path = orthogonal_mp_gram(G, Xy, n_nonzero_coefs=5, return_path=True)
  190. last = orthogonal_mp_gram(G, Xy, n_nonzero_coefs=5, return_path=False)
  191. assert path.shape == (n_features, n_targets, 5)
  192. assert_array_almost_equal(path[:, :, -1], last)
  193. def test_omp_return_path_prop_with_gram():
  194. path = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=True, precompute=True)
  195. last = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=False, precompute=True)
  196. assert path.shape == (n_features, n_targets, 5)
  197. assert_array_almost_equal(path[:, :, -1], last)
  198. # TODO(1.4): 'normalize' to be removed
  199. @pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
  200. def test_omp_cv():
  201. y_ = y[:, 0]
  202. gamma_ = gamma[:, 0]
  203. ompcv = OrthogonalMatchingPursuitCV(
  204. normalize=True, fit_intercept=False, max_iter=10
  205. )
  206. ompcv.fit(X, y_)
  207. assert ompcv.n_nonzero_coefs_ == n_nonzero_coefs
  208. assert_array_almost_equal(ompcv.coef_, gamma_)
  209. omp = OrthogonalMatchingPursuit(
  210. normalize=True, fit_intercept=False, n_nonzero_coefs=ompcv.n_nonzero_coefs_
  211. )
  212. omp.fit(X, y_)
  213. assert_array_almost_equal(ompcv.coef_, omp.coef_)
  214. # TODO(1.4): 'normalize' to be removed
  215. @pytest.mark.filterwarnings("ignore:'normalize' was deprecated")
  216. def test_omp_reaches_least_squares():
  217. # Use small simple data; it's a sanity check but OMP can stop early
  218. rng = check_random_state(0)
  219. n_samples, n_features = (10, 8)
  220. n_targets = 3
  221. X = rng.randn(n_samples, n_features)
  222. Y = rng.randn(n_samples, n_targets)
  223. omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_features)
  224. lstsq = LinearRegression()
  225. omp.fit(X, Y)
  226. lstsq.fit(X, Y)
  227. assert_array_almost_equal(omp.coef_, lstsq.coef_)
  228. @pytest.mark.parametrize("data_type", (np.float32, np.float64))
  229. def test_omp_gram_dtype_match(data_type):
  230. # verify matching input data type and output data type
  231. coef = orthogonal_mp_gram(
  232. G.astype(data_type), Xy.astype(data_type), n_nonzero_coefs=5
  233. )
  234. assert coef.dtype == data_type
  235. def test_omp_gram_numerical_consistency():
  236. # verify numericaly consistency among np.float32 and np.float64
  237. coef_32 = orthogonal_mp_gram(
  238. G.astype(np.float32), Xy.astype(np.float32), n_nonzero_coefs=5
  239. )
  240. coef_64 = orthogonal_mp_gram(
  241. G.astype(np.float32), Xy.astype(np.float64), n_nonzero_coefs=5
  242. )
  243. assert_allclose(coef_32, coef_64)