| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 |
- from collections import namedtuple
- import numpy as np
- import pytest
- from numpy.testing import assert_array_equal
- from scipy.sparse import csr_matrix
- from sklearn._config import config_context, get_config
- from sklearn.utils._set_output import (
- _get_output_config,
- _safe_set_output,
- _SetOutputMixin,
- _wrap_in_pandas_container,
- )
def test__wrap_in_pandas_container_dense():
    """Check that a dense array is wrapped in a DataFrame with given labels.

    The column names are supplied through a zero-argument callable and the
    index as an array; both must show up unchanged on the result.
    """
    pd = pytest.importorskip("pandas")

    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    expected_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    expected_index = np.asarray([0, 1])

    def columns_factory():
        return expected_columns

    wrapped = _wrap_in_pandas_container(
        data, columns=columns_factory, index=expected_index
    )

    assert isinstance(wrapped, pd.DataFrame)
    assert_array_equal(wrapped.columns, expected_columns)
    assert_array_equal(wrapped.index, expected_index)
def test__wrap_in_pandas_container_dense_update_columns_and_index():
    """Check that columns are replaced but a DataFrame input keeps its index."""
    pd = pytest.importorskip("pandas")

    original = pd.DataFrame([[1, 0, 3], [0, 0, 1]], columns=["a", "b", "c"])
    replacement_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
    requested_index = [10, 12]

    rewrapped = _wrap_in_pandas_container(
        original, columns=replacement_columns, index=requested_index
    )

    assert_array_equal(rewrapped.columns, replacement_columns)
    # The requested index is not applied when the input is already a
    # DataFrame: the original index is expected on the result.
    assert_array_equal(rewrapped.index, original.index)
def test__wrap_in_pandas_container_error_validation():
    """Check that wrapping sparse data raises a ValueError."""
    dense = np.asarray([[1, 0, 3], [0, 0, 1]])
    sparse = csr_matrix(dense)

    expected_msg = "Pandas output does not support sparse data"
    with pytest.raises(ValueError, match=expected_msg):
        _wrap_in_pandas_container(sparse, columns=["a", "b", "c"])
class EstimatorWithoutSetOutputAndWithoutTransform:
    """Empty stand-in estimator: defines neither `transform` nor `set_output`."""
class EstimatorNoSetOutputWithTransform:
    """Stand-in estimator with an identity `transform` and no `set_output`."""

    def transform(self, X, y=None):
        """Return ``X`` unchanged; ``y`` is accepted and ignored."""
        return X  # pragma: no cover
class EstimatorWithSetOutput(_SetOutputMixin):
    """Identity transformer built on ``_SetOutputMixin``.

    It defines both ``transform`` and ``get_feature_names_out``; the tests in
    this module use it to exercise the ``set_output`` machinery.
    """

    def fit(self, X, y=None):
        """Store the size of the second axis of ``X`` and return ``self``."""
        n_features = X.shape[1]
        self.n_features_in_ = n_features
        return self

    def transform(self, X, y=None):
        """Return ``X`` unchanged; ``y`` is ignored."""
        return X

    def get_feature_names_out(self, input_features=None):
        """Return ``X0, X1, ...`` as an object array, one per fitted feature.

        ``input_features`` is accepted but not used.
        """
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
def test__safe_set_output():
    """Check _safe_set_output across the three kinds of test estimators."""

    def dense_config(estimator):
        # Dense container currently configured for `transform`.
        return _get_output_config("transform", estimator)["dense"]

    # Without a transform method, configuring the transform output does not
    # raise.
    no_transform = EstimatorWithoutSetOutputAndWithoutTransform()
    _safe_set_output(no_transform, transform="pandas")

    # With a transform method but no set_output, configuring must raise.
    no_set_output = EstimatorNoSetOutputWithTransform()
    with pytest.raises(ValueError, match="Unable to configure output"):
        _safe_set_output(no_set_output, transform="pandas")

    # With set_output available, each requested value is stored in turn.
    est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
    for requested in ("pandas", "default"):
        _safe_set_output(est, transform=requested)
        assert dense_config(est) == requested

    # transform=None is a no-op, so the config remains "default"
    _safe_set_output(est, transform=None)
    assert dense_config(est) == "default"
class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
    """``_SetOutputMixin`` subclass that lacks ``get_feature_names_out``."""

    def transform(self, X, y=None):
        """Return ``X`` unchanged; ``y`` is ignored."""
        return X  # pragma: no cover
def test_set_output_mixin():
    """Estimator without get_feature_names_out does not define `set_output`."""
    estimator = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
    has_set_output = hasattr(estimator, "set_output")
    assert not has_set_output
def test__safe_set_output_error():
    """Check that transform raises once an invalid config has been set.

    The invalid value is stored through ``_safe_set_output``; the error is
    only expected later, when ``transform`` is called.
    """
    est = EstimatorWithSetOutput()
    _safe_set_output(est, transform="bad")

    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    with pytest.raises(ValueError, match="output config must be 'default'"):
        est.transform(data)
def test_set_output_method():
    """Check that set_output(transform="pandas") makes transform return pandas."""
    pd = pytest.importorskip("pandas")

    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(data)

    # transform=None is a no-op: the same estimator comes back and the
    # output is still a plain ndarray.
    returned = est.set_output(transform=None)
    assert returned is est
    default_out = returned.transform(data)
    assert isinstance(default_out, np.ndarray)

    est.set_output(transform="pandas")
    pandas_out = est.transform(data)
    assert isinstance(pandas_out, pd.DataFrame)
def test_set_output_method_error():
    """Check that transform fails after set_output got an invalid value."""
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    est = EstimatorWithSetOutput().fit(data)
    est.set_output(transform="bad")

    expected_msg = "output config must be 'default'"
    with pytest.raises(ValueError, match=expected_msg):
        est.transform(data)
def test__get_output_config():
    """Check how _get_output_config combines global and estimator config."""

    def dense_config(*estimator):
        # Forward the optional estimator exactly as given and pick "dense".
        return _get_output_config("transform", *estimator)["dense"]

    # Without a configuration set, the global config is used
    assert dense_config() == get_config()["transform_output"]

    with config_context(transform_output="pandas"):
        # with estimator=None, the global config is used
        assert dense_config() == "pandas"

        # An estimator that has no set_output also sees the global config
        est = EstimatorNoSetOutputWithTransform()
        assert dense_config(est) == "pandas"

        # If estimator has no config, use global config
        est = EstimatorWithSetOutput()
        assert dense_config(est) == "pandas"

        # If estimator has a config, use local config
        est.set_output(transform="default")
        assert dense_config(est) == "default"

    est.set_output(transform="pandas")
    assert dense_config(est) == "pandas"
class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
    """``_SetOutputMixin`` subclass declared with ``auto_wrap_output_keys=None``."""

    def transform(self, X, y=None):
        """Return ``X`` unchanged; ``y`` is ignored."""
        return X
def test_get_output_auto_wrap_false():
    """Check that auto_wrap_output_keys=None does not wrap."""
    estimator = EstimatorWithSetOutputNoAutoWrap()
    assert not hasattr(estimator, "set_output")

    # An unwrapped transform hands back the very same object it was given.
    data = np.asarray([[1, 0, 3], [0, 0, 1]])
    result = estimator.transform(data)
    assert data is result
def test_auto_wrap_output_keys_errors_with_incorrect_input():
    """Check that a string ``auto_wrap_output_keys`` is rejected.

    The error is expected while the class is being defined, so the class
    statement itself sits inside the ``pytest.raises`` block.
    """
    expected_msg = "auto_wrap_output_keys must be None or a tuple of keys."
    with pytest.raises(ValueError, match=expected_msg):

        class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
            pass
class AnotherMixin:
    """Mixin that records a required class keyword on each subclass."""

    def __init_subclass__(cls, custom_parameter, **kwargs):
        """Pass remaining keywords up the chain, then store ``custom_parameter``."""
        super().__init_subclass__(**kwargs)
        setattr(cls, "custom_parameter", custom_parameter)
def test_set_output_mixin_custom_mixin():
    """Check that class keywords reach a second mixin's __init_subclass__."""

    class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
        def transform(self, X, y=None):
            return X

        def get_feature_names_out(self, input_features=None):
            return input_features

    estimator = BothMixinEstimator()
    # The keyword was consumed by AnotherMixin ...
    assert estimator.custom_parameter == 123
    # ... and set_output is still available on the class.
    assert hasattr(estimator, "set_output")
def test__wrap_in_pandas_container_column_errors():
    """If a callable `columns` errors, it has the same semantics as columns=None."""
    pd = pytest.importorskip("pandas")

    def failing_columns():
        raise ValueError("No feature names defined")

    # DataFrame input: the existing column labels are expected to survive.
    frame = pd.DataFrame({"feat1": [1, 2, 3], "feat2": [3, 4, 5]})
    wrapped_frame = _wrap_in_pandas_container(frame, columns=failing_columns)
    assert_array_equal(wrapped_frame.columns, frame.columns)

    # ndarray input: positional integer labels are expected.
    array = np.asarray([[1, 3], [2, 4], [3, 5]])
    wrapped_array = _wrap_in_pandas_container(array, columns=failing_columns)
    assert_array_equal(wrapped_array.columns, range(array.shape[1]))
def test_set_output_mro():
    """Check that multi-inheritance resolves to the correct class method.

    ``C`` inherits from ``A`` (no ``transform`` of its own) and ``B`` (which
    overrides ``transform``); calling ``C().transform`` must reach ``B``'s
    version rather than ``Base``'s.

    Non-regression test gh-25293.
    """

    class Base(_SetOutputMixin):
        def transform(self, X):
            return "Base"  # noqa

    class A(Base):
        pass

    class B(Base):
        def transform(self, X):
            return "B"

    class C(A, B):
        pass

    result = C().transform(None)
    assert result == "B"
class EstimatorWithSetOutputIndex(_SetOutputMixin):
    """Transformer whose ``transform`` returns a DataFrame with a new index."""

    def fit(self, X, y=None):
        """Store the size of the second axis of ``X`` and return ``self``."""
        n_features = X.shape[1]
        self.n_features_in_ = n_features
        return self

    def transform(self, X, y=None):
        """Return the values of ``X`` in a DataFrame indexed ``s0, s1, ...``."""
        import pandas as pd

        values = X.to_numpy()
        # Relabel the rows so the output carries an index of its own.
        row_labels = [f"s{i}" for i in range(X.shape[0])]
        return pd.DataFrame(values, index=row_labels)

    def get_feature_names_out(self, input_features=None):
        """Return ``X0, X1, ...`` as an object array, one per fitted feature."""
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
def test_set_output_pandas_keep_index():
    """Check that set_output does not override index.

    Non-regression test for gh-25730.
    """
    pd = pytest.importorskip("pandas")

    frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
    est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
    est.fit(frame)

    transformed = est.transform(frame)
    # The index chosen inside transform must be the one on the output.
    assert_array_equal(transformed.index, ["s0", "s1"])
class EstimatorReturnTuple(_SetOutputMixin):
    """Transformer returning ``OutputTuple(X, 2 * X)`` from ``transform``."""

    def __init__(self, OutputTuple):
        # Callable used to build the return value of `transform`.
        self.OutputTuple = OutputTuple

    def transform(self, X, y=None):
        """Return ``X`` and its double, packed with ``self.OutputTuple``."""
        doubled = 2 * X
        return self.OutputTuple(X, doubled)
def test_set_output_named_tuple_out():
    """Check that namedtuples are kept by default."""
    Output = namedtuple("Output", "X, Y")
    data = np.asarray([[1, 2, 3]])

    estimator = EstimatorReturnTuple(OutputTuple=Output)
    result = estimator.transform(data)

    # Same namedtuple type, with both fields holding the expected arrays.
    assert isinstance(result, Output)
    assert_array_equal(result.X, data)
    assert_array_equal(result.Y, 2 * data)
class EstimatorWithListInput(_SetOutputMixin):
    """Identity transformer whose ``fit`` only accepts a ``list``."""

    def fit(self, X, y=None):
        """Store the length of the first row of ``X`` and return ``self``."""
        assert isinstance(X, list)
        first_row = X[0]
        self.n_features_in_ = len(first_row)
        return self

    def transform(self, X, y=None):
        """Return ``X`` unchanged; ``y`` is ignored."""
        return X

    def get_feature_names_out(self, input_features=None):
        """Return ``X0, X1, ...`` as an object array, one per fitted feature."""
        names = [f"X{i}" for i in range(self.n_features_in_)]
        return np.asarray(names, dtype=object)
def test_set_output_list_input():
    """Check set_output for list input.

    Non-regression test for #27037.
    """
    pd = pytest.importorskip("pandas")

    rows = [[0, 1, 2, 3], [4, 5, 6, 7]]
    estimator = EstimatorWithListInput()
    estimator.set_output(transform="pandas")

    fitted = estimator.fit(rows)
    output = fitted.transform(rows)

    assert isinstance(output, pd.DataFrame)
    assert_array_equal(output.columns, ["X0", "X1", "X2", "X3"])
|