test_common.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import warnings
  2. import numpy as np
  3. import pytest
  4. from scipy import sparse
  5. from sklearn.base import clone
  6. from sklearn.datasets import load_iris
  7. from sklearn.model_selection import train_test_split
  8. from sklearn.preprocessing import (
  9. MaxAbsScaler,
  10. MinMaxScaler,
  11. PowerTransformer,
  12. QuantileTransformer,
  13. RobustScaler,
  14. StandardScaler,
  15. maxabs_scale,
  16. minmax_scale,
  17. power_transform,
  18. quantile_transform,
  19. robust_scale,
  20. scale,
  21. )
  22. from sklearn.utils._testing import assert_allclose, assert_array_equal
# Iris dataset, loaded once at module import and shared by the tests below.
# The visible test only reads it through ``iris.data.copy()`` before injecting
# NaNs, so the shared object itself is not modified.
iris = load_iris()
  24. def _get_valid_samples_by_column(X, col):
  25. """Get non NaN samples in column of X"""
  26. return X[:, [col]][~np.isnan(X[:, col])]
  27. @pytest.mark.parametrize(
  28. "est, func, support_sparse, strictly_positive, omit_kwargs",
  29. [
  30. (MaxAbsScaler(), maxabs_scale, True, False, []),
  31. (MinMaxScaler(), minmax_scale, False, False, ["clip"]),
  32. (StandardScaler(), scale, False, False, []),
  33. (StandardScaler(with_mean=False), scale, True, False, []),
  34. (PowerTransformer("yeo-johnson"), power_transform, False, False, []),
  35. (PowerTransformer("box-cox"), power_transform, False, True, []),
  36. (QuantileTransformer(n_quantiles=10), quantile_transform, True, False, []),
  37. (RobustScaler(), robust_scale, False, False, []),
  38. (RobustScaler(with_centering=False), robust_scale, True, False, []),
  39. ],
  40. )
  41. def test_missing_value_handling(
  42. est, func, support_sparse, strictly_positive, omit_kwargs
  43. ):
  44. # check that the preprocessing method let pass nan
  45. rng = np.random.RandomState(42)
  46. X = iris.data.copy()
  47. n_missing = 50
  48. X[
  49. rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)
  50. ] = np.nan
  51. if strictly_positive:
  52. X += np.nanmin(X) + 0.1
  53. X_train, X_test = train_test_split(X, random_state=1)
  54. # sanity check
  55. assert not np.all(np.isnan(X_train), axis=0).any()
  56. assert np.any(np.isnan(X_train), axis=0).all()
  57. assert np.any(np.isnan(X_test), axis=0).all()
  58. X_test[:, 0] = np.nan # make sure this boundary case is tested
  59. with warnings.catch_warnings():
  60. warnings.simplefilter("error", RuntimeWarning)
  61. Xt = est.fit(X_train).transform(X_test)
  62. # ensure no warnings are raised
  63. # missing values should still be missing, and only them
  64. assert_array_equal(np.isnan(Xt), np.isnan(X_test))
  65. # check that the function leads to the same results as the class
  66. with warnings.catch_warnings():
  67. warnings.simplefilter("error", RuntimeWarning)
  68. Xt_class = est.transform(X_train)
  69. kwargs = est.get_params()
  70. # remove the parameters which should be omitted because they
  71. # are not defined in the counterpart function of the preprocessing class
  72. for kwarg in omit_kwargs:
  73. _ = kwargs.pop(kwarg)
  74. Xt_func = func(X_train, **kwargs)
  75. assert_array_equal(np.isnan(Xt_func), np.isnan(Xt_class))
  76. assert_allclose(Xt_func[~np.isnan(Xt_func)], Xt_class[~np.isnan(Xt_class)])
  77. # check that the inverse transform keep NaN
  78. Xt_inv = est.inverse_transform(Xt)
  79. assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test))
  80. # FIXME: we can introduce equal_nan=True in recent version of numpy.
  81. # For the moment which just check that non-NaN values are almost equal.
  82. assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)])
  83. for i in range(X.shape[1]):
  84. # train only on non-NaN
  85. est.fit(_get_valid_samples_by_column(X_train, i))
  86. # check transforming with NaN works even when training without NaN
  87. with warnings.catch_warnings():
  88. warnings.simplefilter("error", RuntimeWarning)
  89. Xt_col = est.transform(X_test[:, [i]])
  90. assert_allclose(Xt_col, Xt[:, [i]])
  91. # check non-NaN is handled as before - the 1st column is all nan
  92. if not np.isnan(X_test[:, i]).all():
  93. Xt_col_nonan = est.transform(_get_valid_samples_by_column(X_test, i))
  94. assert_array_equal(Xt_col_nonan, Xt_col[~np.isnan(Xt_col.squeeze())])
  95. if support_sparse:
  96. est_dense = clone(est)
  97. est_sparse = clone(est)
  98. with warnings.catch_warnings():
  99. warnings.simplefilter("error", RuntimeWarning)
  100. Xt_dense = est_dense.fit(X_train).transform(X_test)
  101. Xt_inv_dense = est_dense.inverse_transform(Xt_dense)
  102. for sparse_constructor in (
  103. sparse.csr_matrix,
  104. sparse.csc_matrix,
  105. sparse.bsr_matrix,
  106. sparse.coo_matrix,
  107. sparse.dia_matrix,
  108. sparse.dok_matrix,
  109. sparse.lil_matrix,
  110. ):
  111. # check that the dense and sparse inputs lead to the same results
  112. # precompute the matrix to avoid catching side warnings
  113. X_train_sp = sparse_constructor(X_train)
  114. X_test_sp = sparse_constructor(X_test)
  115. with warnings.catch_warnings():
  116. warnings.simplefilter("ignore", PendingDeprecationWarning)
  117. warnings.simplefilter("error", RuntimeWarning)
  118. Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp)
  119. assert_allclose(Xt_sp.toarray(), Xt_dense)
  120. with warnings.catch_warnings():
  121. warnings.simplefilter("ignore", PendingDeprecationWarning)
  122. warnings.simplefilter("error", RuntimeWarning)
  123. Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
  124. assert_allclose(Xt_inv_sp.toarray(), Xt_inv_dense)
  125. @pytest.mark.parametrize(
  126. "est, func",
  127. [
  128. (MaxAbsScaler(), maxabs_scale),
  129. (MinMaxScaler(), minmax_scale),
  130. (StandardScaler(), scale),
  131. (StandardScaler(with_mean=False), scale),
  132. (PowerTransformer("yeo-johnson"), power_transform),
  133. (
  134. PowerTransformer("box-cox"),
  135. power_transform,
  136. ),
  137. (QuantileTransformer(n_quantiles=3), quantile_transform),
  138. (RobustScaler(), robust_scale),
  139. (RobustScaler(with_centering=False), robust_scale),
  140. ],
  141. )
  142. def test_missing_value_pandas_na_support(est, func):
  143. # Test pandas IntegerArray with pd.NA
  144. pd = pytest.importorskip("pandas")
  145. X = np.array(
  146. [
  147. [1, 2, 3, np.nan, np.nan, 4, 5, 1],
  148. [np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8],
  149. [1, 2, 3, 4, 5, 6, 7, 8],
  150. ]
  151. ).T
  152. # Creates dataframe with IntegerArrays with pd.NA
  153. X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c"])
  154. X_df["c"] = X_df["c"].astype("int")
  155. X_trans = est.fit_transform(X)
  156. X_df_trans = est.fit_transform(X_df)
  157. assert_allclose(X_trans, X_df_trans)