# test_base.py
  1. import numpy as np
  2. import pytest
  3. from numpy.testing import assert_array_equal
  4. from scipy import sparse as sp
  5. from sklearn.base import BaseEstimator
  6. from sklearn.feature_selection._base import SelectorMixin
  7. class StepSelector(SelectorMixin, BaseEstimator):
  8. """Retain every `step` features (beginning with 0).
  9. If `step < 1`, then no features are selected.
  10. """
  11. def __init__(self, step=2):
  12. self.step = step
  13. def fit(self, X, y=None):
  14. X = self._validate_data(X, accept_sparse="csc")
  15. return self
  16. def _get_support_mask(self):
  17. mask = np.zeros(self.n_features_in_, dtype=bool)
  18. if self.step >= 1:
  19. mask[:: self.step] = True
  20. return mask
  21. support = [True, False] * 5
  22. support_inds = [0, 2, 4, 6, 8]
  23. X = np.arange(20).reshape(2, 10)
  24. Xt = np.arange(0, 20, 2).reshape(2, 5)
  25. Xinv = X.copy()
  26. Xinv[:, 1::2] = 0
  27. y = [0, 1]
  28. feature_names = list("ABCDEFGHIJ")
  29. feature_names_t = feature_names[::2]
  30. feature_names_inv = np.array(feature_names)
  31. feature_names_inv[1::2] = ""
  32. def test_transform_dense():
  33. sel = StepSelector()
  34. Xt_actual = sel.fit(X, y).transform(X)
  35. Xt_actual2 = StepSelector().fit_transform(X, y)
  36. assert_array_equal(Xt, Xt_actual)
  37. assert_array_equal(Xt, Xt_actual2)
  38. # Check dtype matches
  39. assert np.int32 == sel.transform(X.astype(np.int32)).dtype
  40. assert np.float32 == sel.transform(X.astype(np.float32)).dtype
  41. # Check 1d list and other dtype:
  42. names_t_actual = sel.transform([feature_names])
  43. assert_array_equal(feature_names_t, names_t_actual.ravel())
  44. # Check wrong shape raises error
  45. with pytest.raises(ValueError):
  46. sel.transform(np.array([[1], [2]]))
  47. def test_transform_sparse():
  48. sparse = sp.csc_matrix
  49. sel = StepSelector()
  50. Xt_actual = sel.fit(sparse(X)).transform(sparse(X))
  51. Xt_actual2 = sel.fit_transform(sparse(X))
  52. assert_array_equal(Xt, Xt_actual.toarray())
  53. assert_array_equal(Xt, Xt_actual2.toarray())
  54. # Check dtype matches
  55. assert np.int32 == sel.transform(sparse(X).astype(np.int32)).dtype
  56. assert np.float32 == sel.transform(sparse(X).astype(np.float32)).dtype
  57. # Check wrong shape raises error
  58. with pytest.raises(ValueError):
  59. sel.transform(np.array([[1], [2]]))
  60. def test_inverse_transform_dense():
  61. sel = StepSelector()
  62. Xinv_actual = sel.fit(X, y).inverse_transform(Xt)
  63. assert_array_equal(Xinv, Xinv_actual)
  64. # Check dtype matches
  65. assert np.int32 == sel.inverse_transform(Xt.astype(np.int32)).dtype
  66. assert np.float32 == sel.inverse_transform(Xt.astype(np.float32)).dtype
  67. # Check 1d list and other dtype:
  68. names_inv_actual = sel.inverse_transform([feature_names_t])
  69. assert_array_equal(feature_names_inv, names_inv_actual.ravel())
  70. # Check wrong shape raises error
  71. with pytest.raises(ValueError):
  72. sel.inverse_transform(np.array([[1], [2]]))
  73. def test_inverse_transform_sparse():
  74. sparse = sp.csc_matrix
  75. sel = StepSelector()
  76. Xinv_actual = sel.fit(sparse(X)).inverse_transform(sparse(Xt))
  77. assert_array_equal(Xinv, Xinv_actual.toarray())
  78. # Check dtype matches
  79. assert np.int32 == sel.inverse_transform(sparse(Xt).astype(np.int32)).dtype
  80. assert np.float32 == sel.inverse_transform(sparse(Xt).astype(np.float32)).dtype
  81. # Check wrong shape raises error
  82. with pytest.raises(ValueError):
  83. sel.inverse_transform(np.array([[1], [2]]))
  84. def test_get_support():
  85. sel = StepSelector()
  86. sel.fit(X, y)
  87. assert_array_equal(support, sel.get_support())
  88. assert_array_equal(support_inds, sel.get_support(indices=True))
  89. def test_output_dataframe():
  90. """Check output dtypes for dataframes is consistent with the input dtypes."""
  91. pd = pytest.importorskip("pandas")
  92. X = pd.DataFrame(
  93. {
  94. "a": pd.Series([1.0, 2.4, 4.5], dtype=np.float32),
  95. "b": pd.Series(["a", "b", "a"], dtype="category"),
  96. "c": pd.Series(["j", "b", "b"], dtype="category"),
  97. "d": pd.Series([3.0, 2.4, 1.2], dtype=np.float64),
  98. }
  99. )
  100. for step in [2, 3]:
  101. sel = StepSelector(step=step).set_output(transform="pandas")
  102. sel.fit(X)
  103. output = sel.transform(X)
  104. for name, dtype in output.dtypes.items():
  105. assert dtype == X.dtypes[name]
  106. # step=0 will select nothing
  107. sel0 = StepSelector(step=0).set_output(transform="pandas")
  108. sel0.fit(X, y)
  109. msg = "No features were selected"
  110. with pytest.warns(UserWarning, match=msg):
  111. output0 = sel0.transform(X)
  112. assert_array_equal(output0.index, X.index)
  113. assert output0.shape == (X.shape[0], 0)