test_chi2.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """
  2. Tests for chi2, currently the only feature selection function designed
  3. specifically to work with sparse matrices.
  4. """
  5. import warnings
  6. import numpy as np
  7. import pytest
  8. import scipy.stats
  9. from scipy.sparse import coo_matrix, csr_matrix
  10. from sklearn.feature_selection import SelectKBest, chi2
  11. from sklearn.feature_selection._univariate_selection import _chisquare
  12. from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal
  13. # Feature 0 is highly informative for class 1;
  14. # feature 1 is the same everywhere;
  15. # feature 2 is a bit informative for class 2.
  16. X = [[2, 1, 2], [9, 1, 1], [6, 1, 2], [0, 1, 2]]
  17. y = [0, 1, 2, 2]
  18. def mkchi2(k):
  19. """Make k-best chi2 selector"""
  20. return SelectKBest(chi2, k=k)
  21. def test_chi2():
  22. # Test Chi2 feature extraction
  23. chi2 = mkchi2(k=1).fit(X, y)
  24. chi2 = mkchi2(k=1).fit(X, y)
  25. assert_array_equal(chi2.get_support(indices=True), [0])
  26. assert_array_equal(chi2.transform(X), np.array(X)[:, [0]])
  27. chi2 = mkchi2(k=2).fit(X, y)
  28. assert_array_equal(sorted(chi2.get_support(indices=True)), [0, 2])
  29. Xsp = csr_matrix(X, dtype=np.float64)
  30. chi2 = mkchi2(k=2).fit(Xsp, y)
  31. assert_array_equal(sorted(chi2.get_support(indices=True)), [0, 2])
  32. Xtrans = chi2.transform(Xsp)
  33. assert_array_equal(Xtrans.shape, [Xsp.shape[0], 2])
  34. # == doesn't work on scipy.sparse matrices
  35. Xtrans = Xtrans.toarray()
  36. Xtrans2 = mkchi2(k=2).fit_transform(Xsp, y).toarray()
  37. assert_array_almost_equal(Xtrans, Xtrans2)
  38. def test_chi2_coo():
  39. # Check that chi2 works with a COO matrix
  40. # (as returned by CountVectorizer, DictVectorizer)
  41. Xcoo = coo_matrix(X)
  42. mkchi2(k=2).fit_transform(Xcoo, y)
  43. # if we got here without an exception, we're safe
  44. def test_chi2_negative():
  45. # Check for proper error on negative numbers in the input X.
  46. X, y = [[0, 1], [-1e-20, 1]], [0, 1]
  47. for X in (X, np.array(X), csr_matrix(X)):
  48. with pytest.raises(ValueError):
  49. chi2(X, y)
  50. def test_chi2_unused_feature():
  51. # Unused feature should evaluate to NaN
  52. # and should issue no runtime warning
  53. with warnings.catch_warnings(record=True) as warned:
  54. warnings.simplefilter("always")
  55. chi, p = chi2([[1, 0], [0, 0]], [1, 0])
  56. for w in warned:
  57. if "divide by zero" in repr(w):
  58. raise AssertionError("Found unexpected warning %s" % w)
  59. assert_array_equal(chi, [1, np.nan])
  60. assert_array_equal(p[1], np.nan)
  61. def test_chisquare():
  62. # Test replacement for scipy.stats.chisquare against the original.
  63. obs = np.array([[2.0, 2.0], [1.0, 1.0]])
  64. exp = np.array([[1.5, 1.5], [1.5, 1.5]])
  65. # call SciPy first because our version overwrites obs
  66. chi_scp, p_scp = scipy.stats.chisquare(obs, exp)
  67. chi_our, p_our = _chisquare(obs, exp)
  68. assert_array_almost_equal(chi_scp, chi_our)
  69. assert_array_almost_equal(p_scp, p_our)