test_base.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. # Author: Gael Varoquaux
  2. # License: BSD 3 clause
  3. import pickle
  4. import re
  5. import warnings
  6. import numpy as np
  7. import pytest
  8. import scipy.sparse as sp
  9. from numpy.testing import assert_allclose
  10. import sklearn
  11. from sklearn import config_context, datasets
  12. from sklearn.base import BaseEstimator, TransformerMixin, clone, is_classifier
  13. from sklearn.decomposition import PCA
  14. from sklearn.exceptions import InconsistentVersionWarning
  15. from sklearn.model_selection import GridSearchCV
  16. from sklearn.pipeline import Pipeline
  17. from sklearn.preprocessing import StandardScaler
  18. from sklearn.svm import SVC
  19. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  20. from sklearn.utils._mocking import MockDataFrame
  21. from sklearn.utils._set_output import _get_output_config
  22. from sklearn.utils._testing import (
  23. assert_array_equal,
  24. assert_no_warnings,
  25. ignore_warnings,
  26. )
  27. #############################################################################
  28. # A few test classes
  29. class MyEstimator(BaseEstimator):
  30. def __init__(self, l1=0, empty=None):
  31. self.l1 = l1
  32. self.empty = empty
  33. class K(BaseEstimator):
  34. def __init__(self, c=None, d=None):
  35. self.c = c
  36. self.d = d
  37. class T(BaseEstimator):
  38. def __init__(self, a=None, b=None):
  39. self.a = a
  40. self.b = b
  41. class NaNTag(BaseEstimator):
  42. def _more_tags(self):
  43. return {"allow_nan": True}
  44. class NoNaNTag(BaseEstimator):
  45. def _more_tags(self):
  46. return {"allow_nan": False}
  47. class OverrideTag(NaNTag):
  48. def _more_tags(self):
  49. return {"allow_nan": False}
  50. class DiamondOverwriteTag(NaNTag, NoNaNTag):
  51. def _more_tags(self):
  52. return dict()
  53. class InheritDiamondOverwriteTag(DiamondOverwriteTag):
  54. pass
  55. class ModifyInitParams(BaseEstimator):
  56. """Deprecated behavior.
  57. Equal parameters but with a type cast.
  58. Doesn't fulfill a is a
  59. """
  60. def __init__(self, a=np.array([0])):
  61. self.a = a.copy()
  62. class Buggy(BaseEstimator):
  63. "A buggy estimator that does not set its parameters right."
  64. def __init__(self, a=None):
  65. self.a = 1
  66. class NoEstimator:
  67. def __init__(self):
  68. pass
  69. def fit(self, X=None, y=None):
  70. return self
  71. def predict(self, X=None):
  72. return None
  73. class VargEstimator(BaseEstimator):
  74. """scikit-learn estimators shouldn't have vargs."""
  75. def __init__(self, *vargs):
  76. pass
  77. #############################################################################
  78. # The tests
  79. def test_clone():
  80. # Tests that clone creates a correct deep copy.
  81. # We create an estimator, make a copy of its original state
  82. # (which, in this case, is the current state of the estimator),
  83. # and check that the obtained copy is a correct deep copy.
  84. from sklearn.feature_selection import SelectFpr, f_classif
  85. selector = SelectFpr(f_classif, alpha=0.1)
  86. new_selector = clone(selector)
  87. assert selector is not new_selector
  88. assert selector.get_params() == new_selector.get_params()
  89. selector = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
  90. new_selector = clone(selector)
  91. assert selector is not new_selector
  92. def test_clone_2():
  93. # Tests that clone doesn't copy everything.
  94. # We first create an estimator, give it an own attribute, and
  95. # make a copy of its original state. Then we check that the copy doesn't
  96. # have the specific attribute we manually added to the initial estimator.
  97. from sklearn.feature_selection import SelectFpr, f_classif
  98. selector = SelectFpr(f_classif, alpha=0.1)
  99. selector.own_attribute = "test"
  100. new_selector = clone(selector)
  101. assert not hasattr(new_selector, "own_attribute")
  102. def test_clone_buggy():
  103. # Check that clone raises an error on buggy estimators.
  104. buggy = Buggy()
  105. buggy.a = 2
  106. with pytest.raises(RuntimeError):
  107. clone(buggy)
  108. no_estimator = NoEstimator()
  109. with pytest.raises(TypeError):
  110. clone(no_estimator)
  111. varg_est = VargEstimator()
  112. with pytest.raises(RuntimeError):
  113. clone(varg_est)
  114. est = ModifyInitParams()
  115. with pytest.raises(RuntimeError):
  116. clone(est)
  117. def test_clone_empty_array():
  118. # Regression test for cloning estimators with empty arrays
  119. clf = MyEstimator(empty=np.array([]))
  120. clf2 = clone(clf)
  121. assert_array_equal(clf.empty, clf2.empty)
  122. clf = MyEstimator(empty=sp.csr_matrix(np.array([[0]])))
  123. clf2 = clone(clf)
  124. assert_array_equal(clf.empty.data, clf2.empty.data)
  125. def test_clone_nan():
  126. # Regression test for cloning estimators with default parameter as np.nan
  127. clf = MyEstimator(empty=np.nan)
  128. clf2 = clone(clf)
  129. assert clf.empty is clf2.empty
  130. def test_clone_dict():
  131. # test that clone creates a clone of a dict
  132. orig = {"a": MyEstimator()}
  133. cloned = clone(orig)
  134. assert orig["a"] is not cloned["a"]
  135. def test_clone_sparse_matrices():
  136. sparse_matrix_classes = [
  137. cls
  138. for name in dir(sp)
  139. if name.endswith("_matrix") and type(cls := getattr(sp, name)) is type
  140. ]
  141. for cls in sparse_matrix_classes:
  142. sparse_matrix = cls(np.eye(5))
  143. clf = MyEstimator(empty=sparse_matrix)
  144. clf_cloned = clone(clf)
  145. assert clf.empty.__class__ is clf_cloned.empty.__class__
  146. assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())
  147. def test_clone_estimator_types():
  148. # Check that clone works for parameters that are types rather than
  149. # instances
  150. clf = MyEstimator(empty=MyEstimator)
  151. clf2 = clone(clf)
  152. assert clf.empty is clf2.empty
  153. def test_clone_class_rather_than_instance():
  154. # Check that clone raises expected error message when
  155. # cloning class rather than instance
  156. msg = "You should provide an instance of scikit-learn estimator"
  157. with pytest.raises(TypeError, match=msg):
  158. clone(MyEstimator)
  159. def test_repr():
  160. # Smoke test the repr of the base estimator.
  161. my_estimator = MyEstimator()
  162. repr(my_estimator)
  163. test = T(K(), K())
  164. assert repr(test) == "T(a=K(), b=K())"
  165. some_est = T(a=["long_params"] * 1000)
  166. assert len(repr(some_est)) == 485
  167. def test_str():
  168. # Smoke test the str of the base estimator
  169. my_estimator = MyEstimator()
  170. str(my_estimator)
  171. def test_get_params():
  172. test = T(K(), K)
  173. assert "a__d" in test.get_params(deep=True)
  174. assert "a__d" not in test.get_params(deep=False)
  175. test.set_params(a__d=2)
  176. assert test.a.d == 2
  177. with pytest.raises(ValueError):
  178. test.set_params(a__a=2)
  179. def test_is_classifier():
  180. svc = SVC()
  181. assert is_classifier(svc)
  182. assert is_classifier(GridSearchCV(svc, {"C": [0.1, 1]}))
  183. assert is_classifier(Pipeline([("svc", svc)]))
  184. assert is_classifier(Pipeline([("svc_cv", GridSearchCV(svc, {"C": [0.1, 1]}))]))
  185. def test_set_params():
  186. # test nested estimator parameter setting
  187. clf = Pipeline([("svc", SVC())])
  188. # non-existing parameter in svc
  189. with pytest.raises(ValueError):
  190. clf.set_params(svc__stupid_param=True)
  191. # non-existing parameter of pipeline
  192. with pytest.raises(ValueError):
  193. clf.set_params(svm__stupid_param=True)
  194. # we don't currently catch if the things in pipeline are estimators
  195. # bad_pipeline = Pipeline([("bad", NoEstimator())])
  196. # assert_raises(AttributeError, bad_pipeline.set_params,
  197. # bad__stupid_param=True)
  198. def test_set_params_passes_all_parameters():
  199. # Make sure all parameters are passed together to set_params
  200. # of nested estimator. Regression test for #9944
  201. class TestDecisionTree(DecisionTreeClassifier):
  202. def set_params(self, **kwargs):
  203. super().set_params(**kwargs)
  204. # expected_kwargs is in test scope
  205. assert kwargs == expected_kwargs
  206. return self
  207. expected_kwargs = {"max_depth": 5, "min_samples_leaf": 2}
  208. for est in [
  209. Pipeline([("estimator", TestDecisionTree())]),
  210. GridSearchCV(TestDecisionTree(), {}),
  211. ]:
  212. est.set_params(estimator__max_depth=5, estimator__min_samples_leaf=2)
  213. def test_set_params_updates_valid_params():
  214. # Check that set_params tries to set SVC().C, not
  215. # DecisionTreeClassifier().C
  216. gscv = GridSearchCV(DecisionTreeClassifier(), {})
  217. gscv.set_params(estimator=SVC(), estimator__C=42.0)
  218. assert gscv.estimator.C == 42.0
  219. @pytest.mark.parametrize(
  220. "tree,dataset",
  221. [
  222. (
  223. DecisionTreeClassifier(max_depth=2, random_state=0),
  224. datasets.make_classification(random_state=0),
  225. ),
  226. (
  227. DecisionTreeRegressor(max_depth=2, random_state=0),
  228. datasets.make_regression(random_state=0),
  229. ),
  230. ],
  231. )
  232. def test_score_sample_weight(tree, dataset):
  233. rng = np.random.RandomState(0)
  234. # check that the score with and without sample weights are different
  235. X, y = dataset
  236. tree.fit(X, y)
  237. # generate random sample weights
  238. sample_weight = rng.randint(1, 10, size=len(y))
  239. score_unweighted = tree.score(X, y)
  240. score_weighted = tree.score(X, y, sample_weight=sample_weight)
  241. msg = "Unweighted and weighted scores are unexpectedly equal"
  242. assert score_unweighted != score_weighted, msg
  243. def test_clone_pandas_dataframe():
  244. class DummyEstimator(TransformerMixin, BaseEstimator):
  245. """This is a dummy class for generating numerical features
  246. This feature extractor extracts numerical features from pandas data
  247. frame.
  248. Parameters
  249. ----------
  250. df: pandas data frame
  251. The pandas data frame parameter.
  252. Notes
  253. -----
  254. """
  255. def __init__(self, df=None, scalar_param=1):
  256. self.df = df
  257. self.scalar_param = scalar_param
  258. def fit(self, X, y=None):
  259. pass
  260. def transform(self, X):
  261. pass
  262. # build and clone estimator
  263. d = np.arange(10)
  264. df = MockDataFrame(d)
  265. e = DummyEstimator(df, scalar_param=1)
  266. cloned_e = clone(e)
  267. # the test
  268. assert (e.df == cloned_e.df).values.all()
  269. assert e.scalar_param == cloned_e.scalar_param
  270. def test_clone_protocol():
  271. """Checks that clone works with `__sklearn_clone__` protocol."""
  272. class FrozenEstimator(BaseEstimator):
  273. def __init__(self, fitted_estimator):
  274. self.fitted_estimator = fitted_estimator
  275. def __getattr__(self, name):
  276. return getattr(self.fitted_estimator, name)
  277. def __sklearn_clone__(self):
  278. return self
  279. def fit(self, *args, **kwargs):
  280. return self
  281. def fit_transform(self, *args, **kwargs):
  282. return self.fitted_estimator.transform(*args, **kwargs)
  283. X = np.array([[-1, -1], [-2, -1], [-3, -2]])
  284. pca = PCA().fit(X)
  285. components = pca.components_
  286. frozen_pca = FrozenEstimator(pca)
  287. assert_allclose(frozen_pca.components_, components)
  288. # Calling PCA methods such as `get_feature_names_out` still works
  289. assert_array_equal(frozen_pca.get_feature_names_out(), pca.get_feature_names_out())
  290. # Fitting on a new data does not alter `components_`
  291. X_new = np.asarray([[-1, 2], [3, 4], [1, 2]])
  292. frozen_pca.fit(X_new)
  293. assert_allclose(frozen_pca.components_, components)
  294. # `fit_transform` does not alter state
  295. frozen_pca.fit_transform(X_new)
  296. assert_allclose(frozen_pca.components_, components)
  297. # Cloning estimator is a no-op
  298. clone_frozen_pca = clone(frozen_pca)
  299. assert clone_frozen_pca is frozen_pca
  300. assert_allclose(clone_frozen_pca.components_, components)
  301. def test_pickle_version_warning_is_not_raised_with_matching_version():
  302. iris = datasets.load_iris()
  303. tree = DecisionTreeClassifier().fit(iris.data, iris.target)
  304. tree_pickle = pickle.dumps(tree)
  305. assert b"version" in tree_pickle
  306. tree_restored = assert_no_warnings(pickle.loads, tree_pickle)
  307. # test that we can predict with the restored decision tree classifier
  308. score_of_original = tree.score(iris.data, iris.target)
  309. score_of_restored = tree_restored.score(iris.data, iris.target)
  310. assert score_of_original == score_of_restored
  311. class TreeBadVersion(DecisionTreeClassifier):
  312. def __getstate__(self):
  313. return dict(self.__dict__.items(), _sklearn_version="something")
  314. pickle_error_message = (
  315. "Trying to unpickle estimator {estimator} from "
  316. "version {old_version} when using version "
  317. "{current_version}. This might "
  318. "lead to breaking code or invalid results. "
  319. "Use at your own risk."
  320. )
  321. def test_pickle_version_warning_is_issued_upon_different_version():
  322. iris = datasets.load_iris()
  323. tree = TreeBadVersion().fit(iris.data, iris.target)
  324. tree_pickle_other = pickle.dumps(tree)
  325. message = pickle_error_message.format(
  326. estimator="TreeBadVersion",
  327. old_version="something",
  328. current_version=sklearn.__version__,
  329. )
  330. with pytest.warns(UserWarning, match=message) as warning_record:
  331. pickle.loads(tree_pickle_other)
  332. message = warning_record.list[0].message
  333. assert isinstance(message, InconsistentVersionWarning)
  334. assert message.estimator_name == "TreeBadVersion"
  335. assert message.original_sklearn_version == "something"
  336. assert message.current_sklearn_version == sklearn.__version__
  337. class TreeNoVersion(DecisionTreeClassifier):
  338. def __getstate__(self):
  339. return self.__dict__
  340. def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
  341. iris = datasets.load_iris()
  342. # TreeNoVersion has no getstate, like pre-0.18
  343. tree = TreeNoVersion().fit(iris.data, iris.target)
  344. tree_pickle_noversion = pickle.dumps(tree)
  345. assert b"version" not in tree_pickle_noversion
  346. message = pickle_error_message.format(
  347. estimator="TreeNoVersion",
  348. old_version="pre-0.18",
  349. current_version=sklearn.__version__,
  350. )
  351. # check we got the warning about using pre-0.18 pickle
  352. with pytest.warns(UserWarning, match=message):
  353. pickle.loads(tree_pickle_noversion)
  354. def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
  355. iris = datasets.load_iris()
  356. tree = TreeNoVersion().fit(iris.data, iris.target)
  357. tree_pickle_noversion = pickle.dumps(tree)
  358. try:
  359. module_backup = TreeNoVersion.__module__
  360. TreeNoVersion.__module__ = "notsklearn"
  361. assert_no_warnings(pickle.loads, tree_pickle_noversion)
  362. finally:
  363. TreeNoVersion.__module__ = module_backup
  364. class DontPickleAttributeMixin:
  365. def __getstate__(self):
  366. data = self.__dict__.copy()
  367. data["_attribute_not_pickled"] = None
  368. return data
  369. def __setstate__(self, state):
  370. state["_restored"] = True
  371. self.__dict__.update(state)
  372. class MultiInheritanceEstimator(DontPickleAttributeMixin, BaseEstimator):
  373. def __init__(self, attribute_pickled=5):
  374. self.attribute_pickled = attribute_pickled
  375. self._attribute_not_pickled = None
  376. def test_pickling_when_getstate_is_overwritten_by_mixin():
  377. estimator = MultiInheritanceEstimator()
  378. estimator._attribute_not_pickled = "this attribute should not be pickled"
  379. serialized = pickle.dumps(estimator)
  380. estimator_restored = pickle.loads(serialized)
  381. assert estimator_restored.attribute_pickled == 5
  382. assert estimator_restored._attribute_not_pickled is None
  383. assert estimator_restored._restored
  384. def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
  385. try:
  386. estimator = MultiInheritanceEstimator()
  387. text = "this attribute should not be pickled"
  388. estimator._attribute_not_pickled = text
  389. old_mod = type(estimator).__module__
  390. type(estimator).__module__ = "notsklearn"
  391. serialized = estimator.__getstate__()
  392. assert serialized == {"_attribute_not_pickled": None, "attribute_pickled": 5}
  393. serialized["attribute_pickled"] = 4
  394. estimator.__setstate__(serialized)
  395. assert estimator.attribute_pickled == 4
  396. assert estimator._restored
  397. finally:
  398. type(estimator).__module__ = old_mod
  399. class SingleInheritanceEstimator(BaseEstimator):
  400. def __init__(self, attribute_pickled=5):
  401. self.attribute_pickled = attribute_pickled
  402. self._attribute_not_pickled = None
  403. def __getstate__(self):
  404. data = self.__dict__.copy()
  405. data["_attribute_not_pickled"] = None
  406. return data
  407. @ignore_warnings(category=(UserWarning))
  408. def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
  409. estimator = SingleInheritanceEstimator()
  410. estimator._attribute_not_pickled = "this attribute should not be pickled"
  411. serialized = pickle.dumps(estimator)
  412. estimator_restored = pickle.loads(serialized)
  413. assert estimator_restored.attribute_pickled == 5
  414. assert estimator_restored._attribute_not_pickled is None
  415. def test_tag_inheritance():
  416. # test that changing tags by inheritance is not allowed
  417. nan_tag_est = NaNTag()
  418. no_nan_tag_est = NoNaNTag()
  419. assert nan_tag_est._get_tags()["allow_nan"]
  420. assert not no_nan_tag_est._get_tags()["allow_nan"]
  421. redefine_tags_est = OverrideTag()
  422. assert not redefine_tags_est._get_tags()["allow_nan"]
  423. diamond_tag_est = DiamondOverwriteTag()
  424. assert diamond_tag_est._get_tags()["allow_nan"]
  425. inherit_diamond_tag_est = InheritDiamondOverwriteTag()
  426. assert inherit_diamond_tag_est._get_tags()["allow_nan"]
  427. def test_raises_on_get_params_non_attribute():
  428. class MyEstimator(BaseEstimator):
  429. def __init__(self, param=5):
  430. pass
  431. def fit(self, X, y=None):
  432. return self
  433. est = MyEstimator()
  434. msg = "'MyEstimator' object has no attribute 'param'"
  435. with pytest.raises(AttributeError, match=msg):
  436. est.get_params()
  437. def test_repr_mimebundle_():
  438. # Checks the display configuration flag controls the json output
  439. tree = DecisionTreeClassifier()
  440. output = tree._repr_mimebundle_()
  441. assert "text/plain" in output
  442. assert "text/html" in output
  443. with config_context(display="text"):
  444. output = tree._repr_mimebundle_()
  445. assert "text/plain" in output
  446. assert "text/html" not in output
  447. def test_repr_html_wraps():
  448. # Checks the display configuration flag controls the html output
  449. tree = DecisionTreeClassifier()
  450. output = tree._repr_html_()
  451. assert "<style>" in output
  452. with config_context(display="text"):
  453. msg = "_repr_html_ is only defined when"
  454. with pytest.raises(AttributeError, match=msg):
  455. output = tree._repr_html_()
  456. def test_n_features_in_validation():
  457. """Check that `_check_n_features` validates data when reset=False"""
  458. est = MyEstimator()
  459. X_train = [[1, 2, 3], [4, 5, 6]]
  460. est._check_n_features(X_train, reset=True)
  461. assert est.n_features_in_ == 3
  462. msg = "X does not contain any features, but MyEstimator is expecting 3 features"
  463. with pytest.raises(ValueError, match=msg):
  464. est._check_n_features("invalid X", reset=False)
  465. def test_n_features_in_no_validation():
  466. """Check that `_check_n_features` does not validate data when
  467. n_features_in_ is not defined."""
  468. est = MyEstimator()
  469. est._check_n_features("invalid X", reset=True)
  470. assert not hasattr(est, "n_features_in_")
  471. # does not raise
  472. est._check_n_features("invalid X", reset=False)
  473. def test_feature_names_in():
  474. """Check that feature_name_in are recorded by `_validate_data`"""
  475. pd = pytest.importorskip("pandas")
  476. iris = datasets.load_iris()
  477. X_np = iris.data
  478. df = pd.DataFrame(X_np, columns=iris.feature_names)
  479. class NoOpTransformer(TransformerMixin, BaseEstimator):
  480. def fit(self, X, y=None):
  481. self._validate_data(X)
  482. return self
  483. def transform(self, X):
  484. self._validate_data(X, reset=False)
  485. return X
  486. # fit on dataframe saves the feature names
  487. trans = NoOpTransformer().fit(df)
  488. assert_array_equal(trans.feature_names_in_, df.columns)
  489. # fit again but on ndarray does not keep the previous feature names (see #21383)
  490. trans.fit(X_np)
  491. assert not hasattr(trans, "feature_names_in_")
  492. trans.fit(df)
  493. msg = "The feature names should match those that were passed"
  494. df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
  495. with pytest.raises(ValueError, match=msg):
  496. trans.transform(df_bad)
  497. # warns when fitted on dataframe and transforming a ndarray
  498. msg = (
  499. "X does not have valid feature names, but NoOpTransformer was "
  500. "fitted with feature names"
  501. )
  502. with pytest.warns(UserWarning, match=msg):
  503. trans.transform(X_np)
  504. # warns when fitted on a ndarray and transforming dataframe
  505. msg = "X has feature names, but NoOpTransformer was fitted without feature names"
  506. trans = NoOpTransformer().fit(X_np)
  507. with pytest.warns(UserWarning, match=msg):
  508. trans.transform(df)
  509. # fit on dataframe with all integer feature names works without warning
  510. df_int_names = pd.DataFrame(X_np)
  511. trans = NoOpTransformer()
  512. with warnings.catch_warnings():
  513. warnings.simplefilter("error", UserWarning)
  514. trans.fit(df_int_names)
  515. # fit on dataframe with no feature names or all integer feature names
  516. # -> do not warn on transform
  517. Xs = [X_np, df_int_names]
  518. for X in Xs:
  519. with warnings.catch_warnings():
  520. warnings.simplefilter("error", UserWarning)
  521. trans.transform(X)
  522. # fit on dataframe with feature names that are mixed raises an error:
  523. df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
  524. trans = NoOpTransformer()
  525. msg = re.escape(
  526. "Feature names are only supported if all input features have string names, "
  527. "but your input has ['int', 'str'] as feature name / column name types. "
  528. "If you want feature names to be stored and validated, you must convert "
  529. "them all to strings, by using X.columns = X.columns.astype(str) for "
  530. "example. Otherwise you can remove feature / column names from your input "
  531. "data, or convert them all to a non-string data type."
  532. )
  533. with pytest.raises(TypeError, match=msg):
  534. trans.fit(df_mixed)
  535. # transform on feature names that are mixed also raises:
  536. with pytest.raises(TypeError, match=msg):
  537. trans.transform(df_mixed)
  538. def test_validate_data_cast_to_ndarray():
  539. """Check cast_to_ndarray option of _validate_data."""
  540. pd = pytest.importorskip("pandas")
  541. iris = datasets.load_iris()
  542. df = pd.DataFrame(iris.data, columns=iris.feature_names)
  543. y = pd.Series(iris.target)
  544. class NoOpTransformer(TransformerMixin, BaseEstimator):
  545. pass
  546. no_op = NoOpTransformer()
  547. X_np_out = no_op._validate_data(df, cast_to_ndarray=True)
  548. assert isinstance(X_np_out, np.ndarray)
  549. assert_allclose(X_np_out, df.to_numpy())
  550. X_df_out = no_op._validate_data(df, cast_to_ndarray=False)
  551. assert X_df_out is df
  552. y_np_out = no_op._validate_data(y=y, cast_to_ndarray=True)
  553. assert isinstance(y_np_out, np.ndarray)
  554. assert_allclose(y_np_out, y.to_numpy())
  555. y_series_out = no_op._validate_data(y=y, cast_to_ndarray=False)
  556. assert y_series_out is y
  557. X_np_out, y_np_out = no_op._validate_data(df, y, cast_to_ndarray=True)
  558. assert isinstance(X_np_out, np.ndarray)
  559. assert_allclose(X_np_out, df.to_numpy())
  560. assert isinstance(y_np_out, np.ndarray)
  561. assert_allclose(y_np_out, y.to_numpy())
  562. X_df_out, y_series_out = no_op._validate_data(df, y, cast_to_ndarray=False)
  563. assert X_df_out is df
  564. assert y_series_out is y
  565. msg = "Validation should be done on X, y or both."
  566. with pytest.raises(ValueError, match=msg):
  567. no_op._validate_data()
  568. def test_clone_keeps_output_config():
  569. """Check that clone keeps the set_output config."""
  570. ss = StandardScaler().set_output(transform="pandas")
  571. config = _get_output_config("transform", ss)
  572. ss_clone = clone(ss)
  573. config_clone = _get_output_config("transform", ss_clone)
  574. assert config == config_clone
  575. class _Empty:
  576. pass
  577. class EmptyEstimator(_Empty, BaseEstimator):
  578. pass
  579. @pytest.mark.parametrize("estimator", [BaseEstimator(), EmptyEstimator()])
  580. def test_estimator_empty_instance_dict(estimator):
  581. """Check that ``__getstate__`` returns an empty ``dict`` with an empty
  582. instance.
  583. Python 3.11+ changed behaviour by returning ``None`` instead of raising an
  584. ``AttributeError``. Non-regression test for gh-25188.
  585. """
  586. state = estimator.__getstate__()
  587. expected = {"_sklearn_version": sklearn.__version__}
  588. assert state == expected
  589. # this should not raise
  590. pickle.loads(pickle.dumps(BaseEstimator()))
  591. def test_estimator_getstate_using_slots_error_message():
  592. """Using a `BaseEstimator` with `__slots__` is not supported."""
  593. class WithSlots:
  594. __slots__ = ("x",)
  595. class Estimator(BaseEstimator, WithSlots):
  596. pass
  597. msg = (
  598. "You cannot use `__slots__` in objects inheriting from "
  599. "`sklearn.base.BaseEstimator`"
  600. )
  601. with pytest.raises(TypeError, match=msg):
  602. Estimator().__getstate__()
  603. with pytest.raises(TypeError, match=msg):
  604. pickle.dumps(Estimator())