# test_variance_threshold.py (2.0 KB)
import numpy as np
import pytest
from scipy.sparse import bsr_matrix, csc_matrix, csr_matrix

from sklearn.feature_selection import VarianceThreshold
from sklearn.utils._testing import assert_array_equal
  6. data = [[0, 1, 2, 3, 4], [0, 2, 2, 3, 5], [1, 1, 2, 4, 0]]
  7. data2 = [[-0.13725701]] * 10
  8. def test_zero_variance():
  9. # Test VarianceThreshold with default setting, zero variance.
  10. for X in [data, csr_matrix(data), csc_matrix(data), bsr_matrix(data)]:
  11. sel = VarianceThreshold().fit(X)
  12. assert_array_equal([0, 1, 3, 4], sel.get_support(indices=True))
  13. with pytest.raises(ValueError):
  14. VarianceThreshold().fit([[0, 1, 2, 3]])
  15. with pytest.raises(ValueError):
  16. VarianceThreshold().fit([[0, 1], [0, 1]])
  17. def test_variance_threshold():
  18. # Test VarianceThreshold with custom variance.
  19. for X in [data, csr_matrix(data)]:
  20. X = VarianceThreshold(threshold=0.4).fit_transform(X)
  21. assert (len(data), 1) == X.shape
  22. @pytest.mark.skipif(
  23. np.var(data2) == 0,
  24. reason=(
  25. "This test is not valid for this platform, "
  26. "as it relies on numerical instabilities."
  27. ),
  28. )
  29. def test_zero_variance_floating_point_error():
  30. # Test that VarianceThreshold(0.0).fit eliminates features that have
  31. # the same value in every sample, even when floating point errors
  32. # cause np.var not to be 0 for the feature.
  33. # See #13691
  34. for X in [data2, csr_matrix(data2), csc_matrix(data2), bsr_matrix(data2)]:
  35. msg = "No feature in X meets the variance threshold 0.00000"
  36. with pytest.raises(ValueError, match=msg):
  37. VarianceThreshold().fit(X)
  38. def test_variance_nan():
  39. arr = np.array(data, dtype=np.float64)
  40. # add single NaN and feature should still be included
  41. arr[0, 0] = np.nan
  42. # make all values in feature NaN and feature should be rejected
  43. arr[:, 1] = np.nan
  44. for X in [arr, csr_matrix(arr), csc_matrix(arr), bsr_matrix(arr)]:
  45. sel = VarianceThreshold().fit(X)
  46. assert_array_equal([0, 3, 4], sel.get_support(indices=True))