test_set_output.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. from collections import namedtuple
  2. import numpy as np
  3. import pytest
  4. from numpy.testing import assert_array_equal
  5. from scipy.sparse import csr_matrix
  6. from sklearn._config import config_context, get_config
  7. from sklearn.utils._set_output import (
  8. _get_output_config,
  9. _safe_set_output,
  10. _SetOutputMixin,
  11. _wrap_in_pandas_container,
  12. )
  13. def test__wrap_in_pandas_container_dense():
  14. """Check _wrap_in_pandas_container for dense data."""
  15. pd = pytest.importorskip("pandas")
  16. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  17. columns = np.asarray(["f0", "f1", "f2"], dtype=object)
  18. index = np.asarray([0, 1])
  19. dense_named = _wrap_in_pandas_container(X, columns=lambda: columns, index=index)
  20. assert isinstance(dense_named, pd.DataFrame)
  21. assert_array_equal(dense_named.columns, columns)
  22. assert_array_equal(dense_named.index, index)
  23. def test__wrap_in_pandas_container_dense_update_columns_and_index():
  24. """Check that _wrap_in_pandas_container overrides columns and index."""
  25. pd = pytest.importorskip("pandas")
  26. X_df = pd.DataFrame([[1, 0, 3], [0, 0, 1]], columns=["a", "b", "c"])
  27. new_columns = np.asarray(["f0", "f1", "f2"], dtype=object)
  28. new_index = [10, 12]
  29. new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
  30. assert_array_equal(new_df.columns, new_columns)
  31. # Index does not change when the input is a DataFrame
  32. assert_array_equal(new_df.index, X_df.index)
  33. def test__wrap_in_pandas_container_error_validation():
  34. """Check errors in _wrap_in_pandas_container."""
  35. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  36. X_csr = csr_matrix(X)
  37. match = "Pandas output does not support sparse data"
  38. with pytest.raises(ValueError, match=match):
  39. _wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])
  40. class EstimatorWithoutSetOutputAndWithoutTransform:
  41. pass
  42. class EstimatorNoSetOutputWithTransform:
  43. def transform(self, X, y=None):
  44. return X # pragma: no cover
  45. class EstimatorWithSetOutput(_SetOutputMixin):
  46. def fit(self, X, y=None):
  47. self.n_features_in_ = X.shape[1]
  48. return self
  49. def transform(self, X, y=None):
  50. return X
  51. def get_feature_names_out(self, input_features=None):
  52. return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
  53. def test__safe_set_output():
  54. """Check _safe_set_output works as expected."""
  55. # Estimator without transform will not raise when setting set_output for transform.
  56. est = EstimatorWithoutSetOutputAndWithoutTransform()
  57. _safe_set_output(est, transform="pandas")
  58. # Estimator with transform but without set_output will raise
  59. est = EstimatorNoSetOutputWithTransform()
  60. with pytest.raises(ValueError, match="Unable to configure output"):
  61. _safe_set_output(est, transform="pandas")
  62. est = EstimatorWithSetOutput().fit(np.asarray([[1, 2, 3]]))
  63. _safe_set_output(est, transform="pandas")
  64. config = _get_output_config("transform", est)
  65. assert config["dense"] == "pandas"
  66. _safe_set_output(est, transform="default")
  67. config = _get_output_config("transform", est)
  68. assert config["dense"] == "default"
  69. # transform is None is a no-op, so the config remains "default"
  70. _safe_set_output(est, transform=None)
  71. config = _get_output_config("transform", est)
  72. assert config["dense"] == "default"
  73. class EstimatorNoSetOutputWithTransformNoFeatureNamesOut(_SetOutputMixin):
  74. def transform(self, X, y=None):
  75. return X # pragma: no cover
  76. def test_set_output_mixin():
  77. """Estimator without get_feature_names_out does not define `set_output`."""
  78. est = EstimatorNoSetOutputWithTransformNoFeatureNamesOut()
  79. assert not hasattr(est, "set_output")
  80. def test__safe_set_output_error():
  81. """Check transform with invalid config."""
  82. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  83. est = EstimatorWithSetOutput()
  84. _safe_set_output(est, transform="bad")
  85. msg = "output config must be 'default'"
  86. with pytest.raises(ValueError, match=msg):
  87. est.transform(X)
  88. def test_set_output_method():
  89. """Check that the output is pandas."""
  90. pd = pytest.importorskip("pandas")
  91. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  92. est = EstimatorWithSetOutput().fit(X)
  93. # transform=None is a no-op
  94. est2 = est.set_output(transform=None)
  95. assert est2 is est
  96. X_trans_np = est2.transform(X)
  97. assert isinstance(X_trans_np, np.ndarray)
  98. est.set_output(transform="pandas")
  99. X_trans_pd = est.transform(X)
  100. assert isinstance(X_trans_pd, pd.DataFrame)
  101. def test_set_output_method_error():
  102. """Check transform fails with invalid transform."""
  103. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  104. est = EstimatorWithSetOutput().fit(X)
  105. est.set_output(transform="bad")
  106. msg = "output config must be 'default'"
  107. with pytest.raises(ValueError, match=msg):
  108. est.transform(X)
  109. def test__get_output_config():
  110. """Check _get_output_config works as expected."""
  111. # Without a configuration set, the global config is used
  112. global_config = get_config()["transform_output"]
  113. config = _get_output_config("transform")
  114. assert config["dense"] == global_config
  115. with config_context(transform_output="pandas"):
  116. # with estimator=None, the global config is used
  117. config = _get_output_config("transform")
  118. assert config["dense"] == "pandas"
  119. est = EstimatorNoSetOutputWithTransform()
  120. config = _get_output_config("transform", est)
  121. assert config["dense"] == "pandas"
  122. est = EstimatorWithSetOutput()
  123. # If estimator has not config, use global config
  124. config = _get_output_config("transform", est)
  125. assert config["dense"] == "pandas"
  126. # If estimator has a config, use local config
  127. est.set_output(transform="default")
  128. config = _get_output_config("transform", est)
  129. assert config["dense"] == "default"
  130. est.set_output(transform="pandas")
  131. config = _get_output_config("transform", est)
  132. assert config["dense"] == "pandas"
  133. class EstimatorWithSetOutputNoAutoWrap(_SetOutputMixin, auto_wrap_output_keys=None):
  134. def transform(self, X, y=None):
  135. return X
  136. def test_get_output_auto_wrap_false():
  137. """Check that auto_wrap_output_keys=None does not wrap."""
  138. est = EstimatorWithSetOutputNoAutoWrap()
  139. assert not hasattr(est, "set_output")
  140. X = np.asarray([[1, 0, 3], [0, 0, 1]])
  141. assert X is est.transform(X)
  142. def test_auto_wrap_output_keys_errors_with_incorrect_input():
  143. msg = "auto_wrap_output_keys must be None or a tuple of keys."
  144. with pytest.raises(ValueError, match=msg):
  145. class BadEstimator(_SetOutputMixin, auto_wrap_output_keys="bad_parameter"):
  146. pass
  147. class AnotherMixin:
  148. def __init_subclass__(cls, custom_parameter, **kwargs):
  149. super().__init_subclass__(**kwargs)
  150. cls.custom_parameter = custom_parameter
  151. def test_set_output_mixin_custom_mixin():
  152. """Check that multiple init_subclasses passes parameters up."""
  153. class BothMixinEstimator(_SetOutputMixin, AnotherMixin, custom_parameter=123):
  154. def transform(self, X, y=None):
  155. return X
  156. def get_feature_names_out(self, input_features=None):
  157. return input_features
  158. est = BothMixinEstimator()
  159. assert est.custom_parameter == 123
  160. assert hasattr(est, "set_output")
  161. def test__wrap_in_pandas_container_column_errors():
  162. """If a callable `columns` errors, it has the same semantics as columns=None."""
  163. pd = pytest.importorskip("pandas")
  164. def get_columns():
  165. raise ValueError("No feature names defined")
  166. X_df = pd.DataFrame({"feat1": [1, 2, 3], "feat2": [3, 4, 5]})
  167. X_wrapped = _wrap_in_pandas_container(X_df, columns=get_columns)
  168. assert_array_equal(X_wrapped.columns, X_df.columns)
  169. X_np = np.asarray([[1, 3], [2, 4], [3, 5]])
  170. X_wrapped = _wrap_in_pandas_container(X_np, columns=get_columns)
  171. assert_array_equal(X_wrapped.columns, range(X_np.shape[1]))
  172. def test_set_output_mro():
  173. """Check that multi-inheritance resolves to the correct class method.
  174. Non-regression test gh-25293.
  175. """
  176. class Base(_SetOutputMixin):
  177. def transform(self, X):
  178. return "Base" # noqa
  179. class A(Base):
  180. pass
  181. class B(Base):
  182. def transform(self, X):
  183. return "B"
  184. class C(A, B):
  185. pass
  186. assert C().transform(None) == "B"
  187. class EstimatorWithSetOutputIndex(_SetOutputMixin):
  188. def fit(self, X, y=None):
  189. self.n_features_in_ = X.shape[1]
  190. return self
  191. def transform(self, X, y=None):
  192. import pandas as pd
  193. # transform by giving output a new index.
  194. return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])
  195. def get_feature_names_out(self, input_features=None):
  196. return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
  197. def test_set_output_pandas_keep_index():
  198. """Check that set_output does not override index.
  199. Non-regression test for gh-25730.
  200. """
  201. pd = pytest.importorskip("pandas")
  202. X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
  203. est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
  204. est.fit(X)
  205. X_trans = est.transform(X)
  206. assert_array_equal(X_trans.index, ["s0", "s1"])
  207. class EstimatorReturnTuple(_SetOutputMixin):
  208. def __init__(self, OutputTuple):
  209. self.OutputTuple = OutputTuple
  210. def transform(self, X, y=None):
  211. return self.OutputTuple(X, 2 * X)
  212. def test_set_output_named_tuple_out():
  213. """Check that namedtuples are kept by default."""
  214. Output = namedtuple("Output", "X, Y")
  215. X = np.asarray([[1, 2, 3]])
  216. est = EstimatorReturnTuple(OutputTuple=Output)
  217. X_trans = est.transform(X)
  218. assert isinstance(X_trans, Output)
  219. assert_array_equal(X_trans.X, X)
  220. assert_array_equal(X_trans.Y, 2 * X)
  221. class EstimatorWithListInput(_SetOutputMixin):
  222. def fit(self, X, y=None):
  223. assert isinstance(X, list)
  224. self.n_features_in_ = len(X[0])
  225. return self
  226. def transform(self, X, y=None):
  227. return X
  228. def get_feature_names_out(self, input_features=None):
  229. return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
  230. def test_set_output_list_input():
  231. """Check set_output for list input.
  232. Non-regression test for #27037.
  233. """
  234. pd = pytest.importorskip("pandas")
  235. X = [[0, 1, 2, 3], [4, 5, 6, 7]]
  236. est = EstimatorWithListInput()
  237. est.set_output(transform="pandas")
  238. X_out = est.fit(X).transform(X)
  239. assert isinstance(X_out, pd.DataFrame)
  240. assert_array_equal(X_out.columns, ["X0", "X1", "X2", "X3"])