| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- import numpy as np
- import pytest
- from scipy.sparse import bsr_matrix, csc_matrix, csr_matrix
- from sklearn.feature_selection import VarianceThreshold
- from sklearn.utils._testing import assert_array_equal
- data = [[0, 1, 2, 3, 4], [0, 2, 2, 3, 5], [1, 1, 2, 4, 0]]
- data2 = [[-0.13725701]] * 10
- def test_zero_variance():
- # Test VarianceThreshold with default setting, zero variance.
- for X in [data, csr_matrix(data), csc_matrix(data), bsr_matrix(data)]:
- sel = VarianceThreshold().fit(X)
- assert_array_equal([0, 1, 3, 4], sel.get_support(indices=True))
- with pytest.raises(ValueError):
- VarianceThreshold().fit([[0, 1, 2, 3]])
- with pytest.raises(ValueError):
- VarianceThreshold().fit([[0, 1], [0, 1]])
- def test_variance_threshold():
- # Test VarianceThreshold with custom variance.
- for X in [data, csr_matrix(data)]:
- X = VarianceThreshold(threshold=0.4).fit_transform(X)
- assert (len(data), 1) == X.shape
- @pytest.mark.skipif(
- np.var(data2) == 0,
- reason=(
- "This test is not valid for this platform, "
- "as it relies on numerical instabilities."
- ),
- )
- def test_zero_variance_floating_point_error():
- # Test that VarianceThreshold(0.0).fit eliminates features that have
- # the same value in every sample, even when floating point errors
- # cause np.var not to be 0 for the feature.
- # See #13691
- for X in [data2, csr_matrix(data2), csc_matrix(data2), bsr_matrix(data2)]:
- msg = "No feature in X meets the variance threshold 0.00000"
- with pytest.raises(ValueError, match=msg):
- VarianceThreshold().fit(X)
- def test_variance_nan():
- arr = np.array(data, dtype=np.float64)
- # add single NaN and feature should still be included
- arr[0, 0] = np.nan
- # make all values in feature NaN and feature should be rejected
- arr[:, 1] = np.nan
- for X in [arr, csr_matrix(arr), csc_matrix(arr), bsr_matrix(arr)]:
- sel = VarianceThreshold().fit(X)
- assert_array_equal([0, 3, 4], sel.get_support(indices=True))
|