# test_from_model.py

import re
import warnings
from unittest.mock import Mock

import numpy as np
import pytest

from sklearn import datasets
from sklearn.base import BaseEstimator
from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression
from sklearn.datasets import make_friedman1
from sklearn.decomposition import PCA
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import (
    ElasticNet,
    ElasticNetCV,
    Lasso,
    LassoCV,
    LogisticRegression,
    PassiveAggressiveClassifier,
    SGDClassifier,
)
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC
from sklearn.utils._testing import (
    MinimalClassifier,
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
    skip_if_32bit,
)


class NaNTag(BaseEstimator):
    def _more_tags(self):
        return {"allow_nan": True}


class NoNaNTag(BaseEstimator):
    def _more_tags(self):
        return {"allow_nan": False}


class NaNTagRandomForest(RandomForestClassifier):
    def _more_tags(self):
        return {"allow_nan": True}
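

# NaNTag, NoNaNTag, and NaNTagRandomForest exercise tag propagation:
# SelectFromModel is expected to inherit the `allow_nan` tag from the wrapped
# estimator (see test_allow_nan_tag_comes_from_estimator below).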


iris = datasets.load_iris()
data, y = iris.data, iris.target
rng = np.random.RandomState(0)


def test_invalid_input():
    clf = SGDClassifier(
        alpha=0.1, max_iter=10, shuffle=True, random_state=None, tol=None
    )
    for threshold in ["gobbledigook", ".5 * gobbledigook"]:
        model = SelectFromModel(clf, threshold=threshold)
        model.fit(data, y)
        with pytest.raises(ValueError):
            model.transform(data)


def test_input_estimator_unchanged():
    # Test that SelectFromModel fits on a clone of the estimator.
    est = RandomForestClassifier()
    transformer = SelectFromModel(estimator=est)
    transformer.fit(data, y)
    assert transformer.estimator is est
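    # SelectFromModel stores the fitted clone on `estimator_`; the instance
    # passed by the user is never refitted in place.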


@pytest.mark.parametrize(
    "max_features, err_type, err_msg",
    [
        (
            data.shape[1] + 1,
            ValueError,
            "max_features ==",
        ),
        (
            lambda X: 1.5,
            TypeError,
            "max_features must be an instance of int, not float.",
        ),
        (
            lambda X: data.shape[1] + 1,
            ValueError,
            "max_features ==",
        ),
        (
            lambda X: -1,
            ValueError,
            "max_features ==",
        ),
    ],
)
def test_max_features_error(max_features, err_type, err_msg):
    err_msg = re.escape(err_msg)
    clf = RandomForestClassifier(n_estimators=5, random_state=0)
    transformer = SelectFromModel(
        estimator=clf, max_features=max_features, threshold=-np.inf
    )
    with pytest.raises(err_type, match=err_msg):
        transformer.fit(data, y)


@pytest.mark.parametrize("max_features", [0, 2, data.shape[1], None])
def test_inferred_max_features_integer(max_features):
    """Check max_features_ and output shape for integer max_features."""
    clf = RandomForestClassifier(n_estimators=5, random_state=0)
    transformer = SelectFromModel(
        estimator=clf, max_features=max_features, threshold=-np.inf
    )
    X_trans = transformer.fit_transform(data, y)
    if max_features is not None:
        assert transformer.max_features_ == max_features
        assert X_trans.shape[1] == transformer.max_features_
    else:
        assert not hasattr(transformer, "max_features_")
        assert X_trans.shape[1] == data.shape[1]


@pytest.mark.parametrize(
    "max_features",
    [lambda X: 1, lambda X: X.shape[1], lambda X: min(X.shape[1], 10000)],
)
def test_inferred_max_features_callable(max_features):
    """Check max_features_ and output shape for callable max_features."""
    clf = RandomForestClassifier(n_estimators=5, random_state=0)
    transformer = SelectFromModel(
        estimator=clf, max_features=max_features, threshold=-np.inf
    )
    X_trans = transformer.fit_transform(data, y)
    assert transformer.max_features_ == max_features(data)
    assert X_trans.shape[1] == transformer.max_features_


@pytest.mark.parametrize("max_features", [lambda X: round(len(X[0]) / 2), 2])
def test_max_features_array_like(max_features):
    X = [
        [0.87, -1.34, 0.31],
        [-2.79, -0.02, -0.85],
        [-1.34, -0.48, -2.55],
        [1.92, 1.48, 0.65],
    ]
    y = [0, 1, 0, 1]
    clf = RandomForestClassifier(n_estimators=5, random_state=0)
    transformer = SelectFromModel(
        estimator=clf, max_features=max_features, threshold=-np.inf
    )
    X_trans = transformer.fit_transform(X, y)
    assert X_trans.shape[1] == transformer.max_features_


@pytest.mark.parametrize(
    "max_features",
    [lambda X: min(X.shape[1], 10000), lambda X: X.shape[1], lambda X: 1],
)
def test_max_features_callable_data(max_features):
    """Tests that the callable passed to `fit` is called on X."""
    clf = RandomForestClassifier(n_estimators=50, random_state=0)
    m = Mock(side_effect=max_features)
    transformer = SelectFromModel(estimator=clf, max_features=m, threshold=-np.inf)
    transformer.fit_transform(data, y)
    m.assert_called_with(data)


class FixedImportanceEstimator(BaseEstimator):
    def __init__(self, importances):
        self.importances = importances

    def fit(self, X, y=None):
        self.feature_importances_ = np.array(self.importances)
        return self
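

# FixedImportanceEstimator above lets tests pin feature_importances_ to known
# values, making importance ties and the selection order fully deterministic.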


def test_max_features():
    # Test max_features parameter using various values
    X, y = datasets.make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )
    max_features = X.shape[1]
    est = RandomForestClassifier(n_estimators=50, random_state=0)

    transformer1 = SelectFromModel(estimator=est, threshold=-np.inf)
    transformer2 = SelectFromModel(
        estimator=est, max_features=max_features, threshold=-np.inf
    )
    X_new1 = transformer1.fit_transform(X, y)
    X_new2 = transformer2.fit_transform(X, y)
    assert_allclose(X_new1, X_new2)

    # Test max_features against actual model.
    transformer1 = SelectFromModel(estimator=Lasso(alpha=0.025, random_state=42))
    X_new1 = transformer1.fit_transform(X, y)
    scores1 = np.abs(transformer1.estimator_.coef_)
    candidate_indices1 = np.argsort(-scores1, kind="mergesort")
    for n_features in range(1, X_new1.shape[1] + 1):
        transformer2 = SelectFromModel(
            estimator=Lasso(alpha=0.025, random_state=42),
            max_features=n_features,
            threshold=-np.inf,
        )
        X_new2 = transformer2.fit_transform(X, y)
        scores2 = np.abs(transformer2.estimator_.coef_)
        candidate_indices2 = np.argsort(-scores2, kind="mergesort")
        assert_allclose(
            X[:, candidate_indices1[:n_features]],
            X[:, candidate_indices2[:n_features]],
        )
    assert_allclose(transformer1.estimator_.coef_, transformer2.estimator_.coef_)


def test_max_features_tiebreak():
    # Test that max_features can break ties among feature importances.
    X, y = datasets.make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )
    max_features = X.shape[1]

    feature_importances = np.array([4, 4, 4, 4, 3, 3, 3, 2, 2, 1])
    for n_features in range(1, max_features + 1):
        transformer = SelectFromModel(
            FixedImportanceEstimator(feature_importances),
            max_features=n_features,
            threshold=-np.inf,
        )
        X_new = transformer.fit_transform(X, y)
        selected_feature_indices = np.where(transformer._get_support_mask())[0]
        assert_array_equal(selected_feature_indices, np.arange(n_features))
        assert X_new.shape[1] == n_features


def test_threshold_and_max_features():
    X, y = datasets.make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )
    est = RandomForestClassifier(n_estimators=50, random_state=0)

    transformer1 = SelectFromModel(estimator=est, max_features=3, threshold=-np.inf)
    X_new1 = transformer1.fit_transform(X, y)

    transformer2 = SelectFromModel(estimator=est, threshold=0.04)
    X_new2 = transformer2.fit_transform(X, y)

    transformer3 = SelectFromModel(estimator=est, max_features=3, threshold=0.04)
    X_new3 = transformer3.fit_transform(X, y)
    assert X_new3.shape[1] == min(X_new1.shape[1], X_new2.shape[1])
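
    # Transforming one row holding the column indices [0, 1, ..., n_features-1]
    # yields exactly the indices of the columns that were kept.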
    selected_indices = transformer3.transform(np.arange(X.shape[1])[np.newaxis, :])
    assert_allclose(X_new3, X[:, selected_indices[0]])


@skip_if_32bit
def test_feature_importances():
    X, y = datasets.make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )

    est = RandomForestClassifier(n_estimators=50, random_state=0)
    for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
        transformer = SelectFromModel(estimator=est, threshold=threshold)
        transformer.fit(X, y)
        assert hasattr(transformer.estimator_, "feature_importances_")

        X_new = transformer.transform(X)
        assert X_new.shape[1] < X.shape[1]

        importances = transformer.estimator_.feature_importances_
        feature_mask = np.abs(importances) > func(importances)
        assert_array_almost_equal(X_new, X[:, feature_mask])


def test_sample_weight():
    # Ensure sample weights are passed to underlying estimator
    X, y = datasets.make_classification(
        n_samples=100,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )

    # Check with sample weights
    sample_weight = np.ones(y.shape)
    sample_weight[y == 1] *= 100

    est = LogisticRegression(random_state=0, fit_intercept=False)
    transformer = SelectFromModel(estimator=est)

    transformer.fit(X, y, sample_weight=None)
    mask = transformer._get_support_mask()
    transformer.fit(X, y, sample_weight=sample_weight)
    weighted_mask = transformer._get_support_mask()
    assert not np.all(weighted_mask == mask)
    transformer.fit(X, y, sample_weight=3 * sample_weight)
    reweighted_mask = transformer._get_support_mask()
    assert np.all(weighted_mask == reweighted_mask)


@pytest.mark.parametrize(
    "estimator",
    [
        Lasso(alpha=0.1, random_state=42),
        LassoCV(random_state=42),
        ElasticNet(l1_ratio=1, random_state=42),
        ElasticNetCV(l1_ratio=[1], random_state=42),
    ],
)
def test_coef_default_threshold(estimator):
    X, y = datasets.make_classification(
        n_samples=100,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )

    # For the Lasso and related models, the threshold defaults to 1e-5
    transformer = SelectFromModel(estimator=estimator)
    transformer.fit(X, y)
    X_new = transformer.transform(X)
    mask = np.abs(transformer.estimator_.coef_) > 1e-5
    assert_array_almost_equal(X_new, X[:, mask])


@skip_if_32bit
def test_2d_coef():
    X, y = datasets.make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
        n_classes=4,
    )

    est = LogisticRegression()
    for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
        for order in [1, 2, np.inf]:
            # Fit SelectFromModel on a multi-class problem.
            transformer = SelectFromModel(
                estimator=LogisticRegression(), threshold=threshold, norm_order=order
            )
            transformer.fit(X, y)
            assert hasattr(transformer.estimator_, "coef_")
            X_new = transformer.transform(X)
            assert X_new.shape[1] < X.shape[1]

            # Manually check that the norm was computed correctly.
            est.fit(X, y)
            importances = np.linalg.norm(est.coef_, axis=0, ord=order)
            feature_mask = importances > func(importances)
            assert_array_almost_equal(X_new, X[:, feature_mask])


def test_partial_fit():
    est = PassiveAggressiveClassifier(
        random_state=0, shuffle=False, max_iter=5, tol=None
    )
    transformer = SelectFromModel(estimator=est)
    transformer.partial_fit(data, y, classes=np.unique(y))
    old_model = transformer.estimator_
    transformer.partial_fit(data, y, classes=np.unique(y))
    new_model = transformer.estimator_
    assert old_model is new_model

    X_transform = transformer.transform(data)
    transformer.fit(np.vstack((data, data)), np.concatenate((y, y)))
    assert_array_almost_equal(X_transform, transformer.transform(data))

    # check that if est doesn't have partial_fit, neither does SelectFromModel
    transformer = SelectFromModel(estimator=RandomForestClassifier())
    assert not hasattr(transformer, "partial_fit")
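

# SelectFromModel only exposes `partial_fit` when the wrapped estimator
# implements it (in current scikit-learn this is done with the `available_if`
# descriptor), which is why the hasattr() check above is meaningful.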


def test_calling_fit_reinitializes():
    est = LinearSVC(dual="auto", random_state=0)
    transformer = SelectFromModel(estimator=est)
    transformer.fit(data, y)
    transformer.set_params(estimator__C=100)
    transformer.fit(data, y)
    assert transformer.estimator_.C == 100


def test_prefit():
    # Test all possible combinations of the prefit parameter.

    # Passing a prefitted estimator and fitting an unfitted one with
    # prefit=False should give the same results.
    clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
    model = SelectFromModel(clf)
    model.fit(data, y)
    X_transform = model.transform(data)
    clf.fit(data, y)

    model = SelectFromModel(clf, prefit=True)
    assert_array_almost_equal(model.transform(data), X_transform)
    model.fit(data, y)
    assert model.estimator_ is not clf

    # Check that the model is refitted if prefit=False and a fitted model is
    # passed.
    model = SelectFromModel(clf, prefit=False)
    model.fit(data, y)
    assert_array_almost_equal(model.transform(data), X_transform)

    # Check that passing an unfitted estimator with `prefit=True` raises a
    # `NotFittedError`.
    clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
    model = SelectFromModel(clf, prefit=True)
    err_msg = "When `prefit=True`, `estimator` is expected to be a fitted estimator."
    with pytest.raises(NotFittedError, match=err_msg):
        model.fit(data, y)
    with pytest.raises(NotFittedError, match=err_msg):
        model.partial_fit(data, y)
    with pytest.raises(NotFittedError, match=err_msg):
        model.transform(data)

    # Check that the internal parameters of a prefitted model are not changed
    # when calling `fit` or `partial_fit` with `prefit=True`.
    clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, tol=None).fit(data, y)
    model = SelectFromModel(clf, prefit=True)
    model.fit(data, y)
    assert_allclose(model.estimator_.coef_, clf.coef_)
    model.partial_fit(data, y)
    assert_allclose(model.estimator_.coef_, clf.coef_)


def test_prefit_max_features():
    """Check the interaction between `prefit` and `max_features`."""
    # case 1: an error should be raised at `transform` if `fit` was not called
    # to validate the attributes
    estimator = RandomForestClassifier(n_estimators=5, random_state=0)
    estimator.fit(data, y)
    model = SelectFromModel(estimator, prefit=True, max_features=lambda X: X.shape[1])

    err_msg = (
        "When `prefit=True` and `max_features` is a callable, call `fit` "
        "before calling `transform`."
    )
    with pytest.raises(NotFittedError, match=err_msg):
        model.transform(data)

    # case 2: `max_features` is not validated and different from an integer
    # FIXME: we cannot validate the upper bound of the attribute at transform
    # and we should force calling `fit` if we intend to force the attribute
    # to have such an upper bound.
    max_features = 2.5
    model.set_params(max_features=max_features)
    with pytest.raises(ValueError, match="`max_features` must be an integer"):
        model.transform(data)


def test_prefit_get_feature_names_out():
    """Check the interaction between prefit and the feature names."""
    clf = RandomForestClassifier(n_estimators=2, random_state=0)
    clf.fit(data, y)
    model = SelectFromModel(clf, prefit=True, max_features=1)

    name = type(model).__name__
    err_msg = (
        f"This {name} instance is not fitted yet. Call 'fit' with "
        "appropriate arguments before using this estimator."
    )
    with pytest.raises(NotFittedError, match=err_msg):
        model.get_feature_names_out()

    model.fit(data, y)
    feature_names = model.get_feature_names_out()
    assert feature_names == ["x3"]


def test_threshold_string():
    est = RandomForestClassifier(n_estimators=50, random_state=0)
    model = SelectFromModel(est, threshold="0.5*mean")
    model.fit(data, y)
    X_transform = model.transform(data)

    # Calculate the threshold from the estimator directly.
    est.fit(data, y)
    threshold = 0.5 * np.mean(est.feature_importances_)
    mask = est.feature_importances_ > threshold
    assert_array_almost_equal(X_transform, data[:, mask])
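

# Threshold strings may carry a scaling factor in front of "mean" or "median"
# (e.g. "0.5*mean" above and "1.0 * mean" below); in scikit-learn these are
# parsed by the private _calculate_threshold helper.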


def test_threshold_without_refitting():
    # Test that the threshold can be set without refitting the model.
    clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
    model = SelectFromModel(clf, threshold="0.1 * mean")
    model.fit(data, y)
    X_transform = model.transform(data)

    # Set a higher threshold to filter out more features.
    model.threshold = "1.0 * mean"
    assert X_transform.shape[1] > model.transform(data).shape[1]


def test_fit_accepts_nan_inf():
    # Test that fit doesn't check for np.inf and np.nan values.
    clf = HistGradientBoostingClassifier(random_state=0)
    model = SelectFromModel(estimator=clf)

    nan_data = data.copy()
    nan_data[0] = np.nan
    nan_data[1] = np.inf

    # Fit on the data containing NaN/inf; fitting on the clean `data` would
    # leave nan_data unused and make the test vacuous.
    model.fit(nan_data, y)


def test_transform_accepts_nan_inf():
    # Test that transform doesn't check for np.inf and np.nan values.
    clf = NaNTagRandomForest(n_estimators=100, random_state=0)
    nan_data = data.copy()

    model = SelectFromModel(estimator=clf)
    model.fit(nan_data, y)

    nan_data[0] = np.nan
    nan_data[1] = np.inf

    model.transform(nan_data)


def test_allow_nan_tag_comes_from_estimator():
    allow_nan_est = NaNTag()
    model = SelectFromModel(estimator=allow_nan_est)
    assert model._get_tags()["allow_nan"] is True

    no_nan_est = NoNaNTag()
    model = SelectFromModel(estimator=no_nan_est)
    assert model._get_tags()["allow_nan"] is False


def _pca_importances(pca_estimator):
    return np.abs(pca_estimator.explained_variance_)


@pytest.mark.parametrize(
    "estimator, importance_getter",
    [
        (
            make_pipeline(PCA(random_state=0), LogisticRegression()),
            "named_steps.logisticregression.coef_",
        ),
        (PCA(random_state=0), _pca_importances),
    ],
)
def test_importance_getter(estimator, importance_getter):
    selector = SelectFromModel(
        estimator, threshold="mean", importance_getter=importance_getter
    )
    selector.fit(data, y)
    assert selector.transform(data).shape[1] == 1
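

# importance_getter accepts either a callable applied to the fitted estimator
# or a dotted attribute path resolved on it; the dotted form above is how the
# pipeline case reaches the coefficients of its final step.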


@pytest.mark.parametrize("PLSEstimator", [CCA, PLSCanonical, PLSRegression])
def test_select_from_model_pls(PLSEstimator):
    """Check the behaviour of SelectFromModel with PLS estimators.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/12410
    """
    X, y = make_friedman1(n_samples=50, n_features=10, random_state=0)
    estimator = PLSEstimator(n_components=1)
    model = make_pipeline(SelectFromModel(estimator), estimator).fit(X, y)
    assert model.score(X, y) > 0.5


def test_estimator_does_not_support_feature_names():
    """SelectFromModel works with estimators that do not support feature_names_in_.

    Non-regression test for #21949.
    """
    pytest.importorskip("pandas")
    X, y = datasets.load_iris(as_frame=True, return_X_y=True)
    all_feature_names = set(X.columns)

    def importance_getter(estimator):
        return np.arange(X.shape[1])

    selector = SelectFromModel(
        MinimalClassifier(), importance_getter=importance_getter
    ).fit(X, y)

    # selector learns the feature names itself
    assert_array_equal(selector.feature_names_in_, X.columns)
    feature_names_out = set(selector.get_feature_names_out())
    assert feature_names_out < all_feature_names
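
    # `transform` on a DataFrame slice must not raise a feature-name
    # UserWarning, even though MinimalClassifier itself ignores names.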
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)

        selector.transform(X.iloc[1:3])


@pytest.mark.parametrize(
    "error, err_msg, max_features",
    (
        [ValueError, "max_features == 10, must be <= 4", 10],
        [ValueError, "max_features == 5, must be <= 4", lambda x: x.shape[1] + 1],
    ),
)
def test_partial_fit_validate_max_features(error, err_msg, max_features):
    """Test that partial_fit from SelectFromModel validates `max_features`."""
    X, y = datasets.make_classification(
        n_samples=100,
        n_features=4,
        random_state=0,
    )

    with pytest.raises(error, match=err_msg):
        SelectFromModel(
            estimator=SGDClassifier(), max_features=max_features
        ).partial_fit(X, y, classes=[0, 1])


@pytest.mark.parametrize("as_frame", [True, False])
def test_partial_fit_validate_feature_names(as_frame):
    """Test that partial_fit from SelectFromModel validates `feature_names_in_`."""
    pytest.importorskip("pandas")
    X, y = datasets.load_iris(as_frame=as_frame, return_X_y=True)

    selector = SelectFromModel(estimator=SGDClassifier(), max_features=4).partial_fit(
        X, y, classes=[0, 1, 2]
    )
    if as_frame:
        assert_array_equal(selector.feature_names_in_, X.columns)
    else:
        assert not hasattr(selector, "feature_names_in_")