test_function_transformer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. import warnings
  2. import numpy as np
  3. import pytest
  4. from scipy import sparse
  5. from sklearn.pipeline import make_pipeline
  6. from sklearn.preprocessing import FunctionTransformer
  7. from sklearn.utils import _safe_indexing
  8. from sklearn.utils._testing import (
  9. _convert_container,
  10. assert_allclose_dense_sparse,
  11. assert_array_equal,
  12. )
  13. def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
  14. def _func(X, *args, **kwargs):
  15. args_store.append(X)
  16. args_store.extend(args)
  17. kwargs_store.update(kwargs)
  18. return func(X)
  19. return _func
  20. def test_delegate_to_func():
  21. # (args|kwargs)_store will hold the positional and keyword arguments
  22. # passed to the function inside the FunctionTransformer.
  23. args_store = []
  24. kwargs_store = {}
  25. X = np.arange(10).reshape((5, 2))
  26. assert_array_equal(
  27. FunctionTransformer(_make_func(args_store, kwargs_store)).transform(X),
  28. X,
  29. "transform should have returned X unchanged",
  30. )
  31. # The function should only have received X.
  32. assert args_store == [
  33. X
  34. ], "Incorrect positional arguments passed to func: {args}".format(args=args_store)
  35. assert (
  36. not kwargs_store
  37. ), "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
  38. # reset the argument stores.
  39. args_store[:] = []
  40. kwargs_store.clear()
  41. transformed = FunctionTransformer(
  42. _make_func(args_store, kwargs_store),
  43. ).transform(X)
  44. assert_array_equal(
  45. transformed, X, err_msg="transform should have returned X unchanged"
  46. )
  47. # The function should have received X
  48. assert args_store == [
  49. X
  50. ], "Incorrect positional arguments passed to func: {args}".format(args=args_store)
  51. assert (
  52. not kwargs_store
  53. ), "Unexpected keyword arguments passed to func: {args}".format(args=kwargs_store)
  54. def test_np_log():
  55. X = np.arange(10).reshape((5, 2))
  56. # Test that the numpy.log example still works.
  57. assert_array_equal(
  58. FunctionTransformer(np.log1p).transform(X),
  59. np.log1p(X),
  60. )
  61. def test_kw_arg():
  62. X = np.linspace(0, 1, num=10).reshape((5, 2))
  63. F = FunctionTransformer(np.around, kw_args=dict(decimals=3))
  64. # Test that rounding is correct
  65. assert_array_equal(F.transform(X), np.around(X, decimals=3))
  66. def test_kw_arg_update():
  67. X = np.linspace(0, 1, num=10).reshape((5, 2))
  68. F = FunctionTransformer(np.around, kw_args=dict(decimals=3))
  69. F.kw_args["decimals"] = 1
  70. # Test that rounding is correct
  71. assert_array_equal(F.transform(X), np.around(X, decimals=1))
  72. def test_kw_arg_reset():
  73. X = np.linspace(0, 1, num=10).reshape((5, 2))
  74. F = FunctionTransformer(np.around, kw_args=dict(decimals=3))
  75. F.kw_args = dict(decimals=1)
  76. # Test that rounding is correct
  77. assert_array_equal(F.transform(X), np.around(X, decimals=1))
  78. def test_inverse_transform():
  79. X = np.array([1, 4, 9, 16]).reshape((2, 2))
  80. # Test that inverse_transform works correctly
  81. F = FunctionTransformer(
  82. func=np.sqrt,
  83. inverse_func=np.around,
  84. inv_kw_args=dict(decimals=3),
  85. )
  86. assert_array_equal(
  87. F.inverse_transform(F.transform(X)),
  88. np.around(np.sqrt(X), decimals=3),
  89. )
  90. def test_check_inverse():
  91. X_dense = np.array([1, 4, 9, 16], dtype=np.float64).reshape((2, 2))
  92. X_list = [X_dense, sparse.csr_matrix(X_dense), sparse.csc_matrix(X_dense)]
  93. for X in X_list:
  94. if sparse.issparse(X):
  95. accept_sparse = True
  96. else:
  97. accept_sparse = False
  98. trans = FunctionTransformer(
  99. func=np.sqrt,
  100. inverse_func=np.around,
  101. accept_sparse=accept_sparse,
  102. check_inverse=True,
  103. validate=True,
  104. )
  105. warning_message = (
  106. "The provided functions are not strictly"
  107. " inverse of each other. If you are sure you"
  108. " want to proceed regardless, set"
  109. " 'check_inverse=False'."
  110. )
  111. with pytest.warns(UserWarning, match=warning_message):
  112. trans.fit(X)
  113. trans = FunctionTransformer(
  114. func=np.expm1,
  115. inverse_func=np.log1p,
  116. accept_sparse=accept_sparse,
  117. check_inverse=True,
  118. validate=True,
  119. )
  120. with warnings.catch_warnings():
  121. warnings.simplefilter("error", UserWarning)
  122. Xt = trans.fit_transform(X)
  123. assert_allclose_dense_sparse(X, trans.inverse_transform(Xt))
  124. # check that we don't check inverse when one of the func or inverse is not
  125. # provided.
  126. trans = FunctionTransformer(
  127. func=np.expm1, inverse_func=None, check_inverse=True, validate=True
  128. )
  129. with warnings.catch_warnings():
  130. warnings.simplefilter("error", UserWarning)
  131. trans.fit(X_dense)
  132. trans = FunctionTransformer(
  133. func=None, inverse_func=np.expm1, check_inverse=True, validate=True
  134. )
  135. with warnings.catch_warnings():
  136. warnings.simplefilter("error", UserWarning)
  137. trans.fit(X_dense)
  138. def test_function_transformer_frame():
  139. pd = pytest.importorskip("pandas")
  140. X_df = pd.DataFrame(np.random.randn(100, 10))
  141. transformer = FunctionTransformer()
  142. X_df_trans = transformer.fit_transform(X_df)
  143. assert hasattr(X_df_trans, "loc")
  144. @pytest.mark.parametrize("X_type", ["array", "series"])
  145. def test_function_transformer_raise_error_with_mixed_dtype(X_type):
  146. """Check that `FunctionTransformer.check_inverse` raises error on mixed dtype."""
  147. mapping = {"one": 1, "two": 2, "three": 3, 5: "five", 6: "six"}
  148. inverse_mapping = {value: key for key, value in mapping.items()}
  149. dtype = "object"
  150. data = ["one", "two", "three", "one", "one", 5, 6]
  151. data = _convert_container(data, X_type, columns_name=["value"], dtype=dtype)
  152. def func(X):
  153. return np.array(
  154. [mapping[_safe_indexing(X, i)] for i in range(X.size)], dtype=object
  155. )
  156. def inverse_func(X):
  157. return _convert_container(
  158. [inverse_mapping[x] for x in X],
  159. X_type,
  160. columns_name=["value"],
  161. dtype=dtype,
  162. )
  163. transformer = FunctionTransformer(
  164. func=func, inverse_func=inverse_func, validate=False, check_inverse=True
  165. )
  166. msg = "'check_inverse' is only supported when all the elements in `X` is numerical."
  167. with pytest.raises(ValueError, match=msg):
  168. transformer.fit(data)
  169. def test_function_transformer_support_all_nummerical_dataframes_check_inverse_True():
  170. """Check support for dataframes with only numerical values."""
  171. pd = pytest.importorskip("pandas")
  172. df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
  173. transformer = FunctionTransformer(
  174. func=lambda x: x + 2, inverse_func=lambda x: x - 2, check_inverse=True
  175. )
  176. # Does not raise an error
  177. df_out = transformer.fit_transform(df)
  178. assert_allclose_dense_sparse(df_out, df + 2)
  179. def test_function_transformer_with_dataframe_and_check_inverse_True():
  180. """Check error is raised when check_inverse=True.
  181. Non-regresion test for gh-25261.
  182. """
  183. pd = pytest.importorskip("pandas")
  184. transformer = FunctionTransformer(
  185. func=lambda x: x, inverse_func=lambda x: x, check_inverse=True
  186. )
  187. df_mixed = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
  188. msg = "'check_inverse' is only supported when all the elements in `X` is numerical."
  189. with pytest.raises(ValueError, match=msg):
  190. transformer.fit(df_mixed)
  191. @pytest.mark.parametrize(
  192. "X, feature_names_out, input_features, expected",
  193. [
  194. (
  195. # NumPy inputs, default behavior: generate names
  196. np.random.rand(100, 3),
  197. "one-to-one",
  198. None,
  199. ("x0", "x1", "x2"),
  200. ),
  201. (
  202. # Pandas input, default behavior: use input feature names
  203. {"a": np.random.rand(100), "b": np.random.rand(100)},
  204. "one-to-one",
  205. None,
  206. ("a", "b"),
  207. ),
  208. (
  209. # NumPy input, feature_names_out=callable
  210. np.random.rand(100, 3),
  211. lambda transformer, input_features: ("a", "b"),
  212. None,
  213. ("a", "b"),
  214. ),
  215. (
  216. # Pandas input, feature_names_out=callable
  217. {"a": np.random.rand(100), "b": np.random.rand(100)},
  218. lambda transformer, input_features: ("c", "d", "e"),
  219. None,
  220. ("c", "d", "e"),
  221. ),
  222. (
  223. # NumPy input, feature_names_out=callable – default input_features
  224. np.random.rand(100, 3),
  225. lambda transformer, input_features: tuple(input_features) + ("a",),
  226. None,
  227. ("x0", "x1", "x2", "a"),
  228. ),
  229. (
  230. # Pandas input, feature_names_out=callable – default input_features
  231. {"a": np.random.rand(100), "b": np.random.rand(100)},
  232. lambda transformer, input_features: tuple(input_features) + ("c",),
  233. None,
  234. ("a", "b", "c"),
  235. ),
  236. (
  237. # NumPy input, input_features=list of names
  238. np.random.rand(100, 3),
  239. "one-to-one",
  240. ("a", "b", "c"),
  241. ("a", "b", "c"),
  242. ),
  243. (
  244. # Pandas input, input_features=list of names
  245. {"a": np.random.rand(100), "b": np.random.rand(100)},
  246. "one-to-one",
  247. ("a", "b"), # must match feature_names_in_
  248. ("a", "b"),
  249. ),
  250. (
  251. # NumPy input, feature_names_out=callable, input_features=list
  252. np.random.rand(100, 3),
  253. lambda transformer, input_features: tuple(input_features) + ("d",),
  254. ("a", "b", "c"),
  255. ("a", "b", "c", "d"),
  256. ),
  257. (
  258. # Pandas input, feature_names_out=callable, input_features=list
  259. {"a": np.random.rand(100), "b": np.random.rand(100)},
  260. lambda transformer, input_features: tuple(input_features) + ("c",),
  261. ("a", "b"), # must match feature_names_in_
  262. ("a", "b", "c"),
  263. ),
  264. ],
  265. )
  266. @pytest.mark.parametrize("validate", [True, False])
  267. def test_function_transformer_get_feature_names_out(
  268. X, feature_names_out, input_features, expected, validate
  269. ):
  270. if isinstance(X, dict):
  271. pd = pytest.importorskip("pandas")
  272. X = pd.DataFrame(X)
  273. transformer = FunctionTransformer(
  274. feature_names_out=feature_names_out, validate=validate
  275. )
  276. transformer.fit_transform(X)
  277. names = transformer.get_feature_names_out(input_features)
  278. assert isinstance(names, np.ndarray)
  279. assert names.dtype == object
  280. assert_array_equal(names, expected)
  281. def test_function_transformer_get_feature_names_out_without_validation():
  282. transformer = FunctionTransformer(feature_names_out="one-to-one", validate=False)
  283. X = np.random.rand(100, 2)
  284. transformer.fit_transform(X)
  285. names = transformer.get_feature_names_out(("a", "b"))
  286. assert isinstance(names, np.ndarray)
  287. assert names.dtype == object
  288. assert_array_equal(names, ("a", "b"))
  289. def test_function_transformer_feature_names_out_is_None():
  290. transformer = FunctionTransformer()
  291. X = np.random.rand(100, 2)
  292. transformer.fit_transform(X)
  293. msg = "This 'FunctionTransformer' has no attribute 'get_feature_names_out'"
  294. with pytest.raises(AttributeError, match=msg):
  295. transformer.get_feature_names_out()
  296. def test_function_transformer_feature_names_out_uses_estimator():
  297. def add_n_random_features(X, n):
  298. return np.concatenate([X, np.random.rand(len(X), n)], axis=1)
  299. def feature_names_out(transformer, input_features):
  300. n = transformer.kw_args["n"]
  301. return list(input_features) + [f"rnd{i}" for i in range(n)]
  302. transformer = FunctionTransformer(
  303. func=add_n_random_features,
  304. feature_names_out=feature_names_out,
  305. kw_args=dict(n=3),
  306. validate=True,
  307. )
  308. pd = pytest.importorskip("pandas")
  309. df = pd.DataFrame({"a": np.random.rand(100), "b": np.random.rand(100)})
  310. transformer.fit_transform(df)
  311. names = transformer.get_feature_names_out()
  312. assert isinstance(names, np.ndarray)
  313. assert names.dtype == object
  314. assert_array_equal(names, ("a", "b", "rnd0", "rnd1", "rnd2"))
  315. def test_function_transformer_validate_inverse():
  316. """Test that function transformer does not reset estimator in
  317. `inverse_transform`."""
  318. def add_constant_feature(X):
  319. X_one = np.ones((X.shape[0], 1))
  320. return np.concatenate((X, X_one), axis=1)
  321. def inverse_add_constant(X):
  322. return X[:, :-1]
  323. X = np.array([[1, 2], [3, 4], [3, 4]])
  324. trans = FunctionTransformer(
  325. func=add_constant_feature,
  326. inverse_func=inverse_add_constant,
  327. validate=True,
  328. )
  329. X_trans = trans.fit_transform(X)
  330. assert trans.n_features_in_ == X.shape[1]
  331. trans.inverse_transform(X_trans)
  332. assert trans.n_features_in_ == X.shape[1]
  333. @pytest.mark.parametrize(
  334. "feature_names_out, expected",
  335. [
  336. ("one-to-one", ["pet", "color"]),
  337. [lambda est, names: [f"{n}_out" for n in names], ["pet_out", "color_out"]],
  338. ],
  339. )
  340. @pytest.mark.parametrize("in_pipeline", [True, False])
  341. def test_get_feature_names_out_dataframe_with_string_data(
  342. feature_names_out, expected, in_pipeline
  343. ):
  344. """Check that get_feature_names_out works with DataFrames with string data."""
  345. pd = pytest.importorskip("pandas")
  346. X = pd.DataFrame({"pet": ["dog", "cat"], "color": ["red", "green"]})
  347. transformer = FunctionTransformer(feature_names_out=feature_names_out)
  348. if in_pipeline:
  349. transformer = make_pipeline(transformer)
  350. X_trans = transformer.fit_transform(X)
  351. assert isinstance(X_trans, pd.DataFrame)
  352. names = transformer.get_feature_names_out()
  353. assert isinstance(names, np.ndarray)
  354. assert names.dtype == object
  355. assert_array_equal(names, expected)
  356. def test_set_output_func():
  357. """Check behavior of set_output with different settings."""
  358. pd = pytest.importorskip("pandas")
  359. X = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 100]})
  360. ft = FunctionTransformer(np.log, feature_names_out="one-to-one")
  361. # no warning is raised when feature_names_out is defined
  362. with warnings.catch_warnings():
  363. warnings.simplefilter("error", UserWarning)
  364. ft.set_output(transform="pandas")
  365. X_trans = ft.fit_transform(X)
  366. assert isinstance(X_trans, pd.DataFrame)
  367. assert_array_equal(X_trans.columns, ["a", "b"])
  368. # If feature_names_out is not defined, then a warning is raised in
  369. # `set_output`
  370. ft = FunctionTransformer(lambda x: 2 * x)
  371. msg = "should return a DataFrame to follow the set_output API"
  372. with pytest.warns(UserWarning, match=msg):
  373. ft.set_output(transform="pandas")
  374. X_trans = ft.fit_transform(X)
  375. assert isinstance(X_trans, pd.DataFrame)
  376. assert_array_equal(X_trans.columns, ["a", "b"])