test_class_weight.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import numpy as np
  2. import pytest
  3. from numpy.testing import assert_allclose
  4. from scipy import sparse
  5. from sklearn.datasets import make_blobs
  6. from sklearn.linear_model import LogisticRegression
  7. from sklearn.tree import DecisionTreeClassifier
  8. from sklearn.utils._testing import assert_almost_equal, assert_array_almost_equal
  9. from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight
  10. def test_compute_class_weight():
  11. # Test (and demo) compute_class_weight.
  12. y = np.asarray([2, 2, 2, 3, 3, 4])
  13. classes = np.unique(y)
  14. cw = compute_class_weight("balanced", classes=classes, y=y)
  15. # total effect of samples is preserved
  16. class_counts = np.bincount(y)[2:]
  17. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  18. assert cw[0] < cw[1] < cw[2]
  19. def test_compute_class_weight_not_present():
  20. # Raise error when y does not contain all class labels
  21. classes = np.arange(4)
  22. y = np.asarray([0, 0, 0, 1, 1, 2])
  23. with pytest.raises(ValueError):
  24. compute_class_weight("balanced", classes=classes, y=y)
  25. # Fix exception in error message formatting when missing label is a string
  26. # https://github.com/scikit-learn/scikit-learn/issues/8312
  27. with pytest.raises(
  28. ValueError, match=r"The classes, \[0, 1, 2, 3\], are not in class_weight"
  29. ):
  30. compute_class_weight({"label_not_present": 1.0}, classes=classes, y=y)
  31. # Raise error when y has items not in classes
  32. classes = np.arange(2)
  33. with pytest.raises(ValueError):
  34. compute_class_weight("balanced", classes=classes, y=y)
  35. with pytest.raises(ValueError):
  36. compute_class_weight({0: 1.0, 1: 2.0}, classes=classes, y=y)
  37. # y contains a unweighted class that is not in class_weights
  38. classes = np.asarray(["cat", "dog"])
  39. y = np.asarray(["dog", "cat", "dog"])
  40. class_weights = {"dogs": 3, "cat": 2}
  41. msg = r"The classes, \['dog'\], are not in class_weight"
  42. with pytest.raises(ValueError, match=msg):
  43. compute_class_weight(class_weights, classes=classes, y=y)
  44. def test_compute_class_weight_dict():
  45. classes = np.arange(3)
  46. class_weights = {0: 1.0, 1: 2.0, 2: 3.0}
  47. y = np.asarray([0, 0, 1, 2])
  48. cw = compute_class_weight(class_weights, classes=classes, y=y)
  49. # When the user specifies class weights, compute_class_weights should just
  50. # return them.
  51. assert_array_almost_equal(np.asarray([1.0, 2.0, 3.0]), cw)
  52. # When a class weight is specified that isn't in classes, the weight is ignored
  53. class_weights = {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.5}
  54. cw = compute_class_weight(class_weights, classes=classes, y=y)
  55. assert_allclose([1.0, 2.0, 3.0], cw)
  56. class_weights = {-1: 5.0, 0: 4.0, 1: 2.0, 2: 3.0}
  57. cw = compute_class_weight(class_weights, classes=classes, y=y)
  58. assert_allclose([4.0, 2.0, 3.0], cw)
  59. def test_compute_class_weight_invariance():
  60. # Test that results with class_weight="balanced" is invariant wrt
  61. # class imbalance if the number of samples is identical.
  62. # The test uses a balanced two class dataset with 100 datapoints.
  63. # It creates three versions, one where class 1 is duplicated
  64. # resulting in 150 points of class 1 and 50 of class 0,
  65. # one where there are 50 points in class 1 and 150 in class 0,
  66. # and one where there are 100 points of each class (this one is balanced
  67. # again).
  68. # With balancing class weights, all three should give the same model.
  69. X, y = make_blobs(centers=2, random_state=0)
  70. # create dataset where class 1 is duplicated twice
  71. X_1 = np.vstack([X] + [X[y == 1]] * 2)
  72. y_1 = np.hstack([y] + [y[y == 1]] * 2)
  73. # create dataset where class 0 is duplicated twice
  74. X_0 = np.vstack([X] + [X[y == 0]] * 2)
  75. y_0 = np.hstack([y] + [y[y == 0]] * 2)
  76. # duplicate everything
  77. X_ = np.vstack([X] * 2)
  78. y_ = np.hstack([y] * 2)
  79. # results should be identical
  80. logreg1 = LogisticRegression(class_weight="balanced").fit(X_1, y_1)
  81. logreg0 = LogisticRegression(class_weight="balanced").fit(X_0, y_0)
  82. logreg = LogisticRegression(class_weight="balanced").fit(X_, y_)
  83. assert_array_almost_equal(logreg1.coef_, logreg0.coef_)
  84. assert_array_almost_equal(logreg.coef_, logreg0.coef_)
  85. def test_compute_class_weight_balanced_negative():
  86. # Test compute_class_weight when labels are negative
  87. # Test with balanced class labels.
  88. classes = np.array([-2, -1, 0])
  89. y = np.asarray([-1, -1, 0, 0, -2, -2])
  90. cw = compute_class_weight("balanced", classes=classes, y=y)
  91. assert len(cw) == len(classes)
  92. assert_array_almost_equal(cw, np.array([1.0, 1.0, 1.0]))
  93. # Test with unbalanced class labels.
  94. y = np.asarray([-1, 0, 0, -2, -2, -2])
  95. cw = compute_class_weight("balanced", classes=classes, y=y)
  96. assert len(cw) == len(classes)
  97. class_counts = np.bincount(y + 2)
  98. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  99. assert_array_almost_equal(cw, [2.0 / 3, 2.0, 1.0])
  100. def test_compute_class_weight_balanced_unordered():
  101. # Test compute_class_weight when classes are unordered
  102. classes = np.array([1, 0, 3])
  103. y = np.asarray([1, 0, 0, 3, 3, 3])
  104. cw = compute_class_weight("balanced", classes=classes, y=y)
  105. class_counts = np.bincount(y)[classes]
  106. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  107. assert_array_almost_equal(cw, [2.0, 1.0, 2.0 / 3])
  108. def test_compute_class_weight_default():
  109. # Test for the case where no weight is given for a present class.
  110. # Current behaviour is to assign the unweighted classes a weight of 1.
  111. y = np.asarray([2, 2, 2, 3, 3, 4])
  112. classes = np.unique(y)
  113. classes_len = len(classes)
  114. # Test for non specified weights
  115. cw = compute_class_weight(None, classes=classes, y=y)
  116. assert len(cw) == classes_len
  117. assert_array_almost_equal(cw, np.ones(3))
  118. # Tests for partly specified weights
  119. cw = compute_class_weight({2: 1.5}, classes=classes, y=y)
  120. assert len(cw) == classes_len
  121. assert_array_almost_equal(cw, [1.5, 1.0, 1.0])
  122. cw = compute_class_weight({2: 1.5, 4: 0.5}, classes=classes, y=y)
  123. assert len(cw) == classes_len
  124. assert_array_almost_equal(cw, [1.5, 1.0, 0.5])
  125. def test_compute_sample_weight():
  126. # Test (and demo) compute_sample_weight.
  127. # Test with balanced classes
  128. y = np.asarray([1, 1, 1, 2, 2, 2])
  129. sample_weight = compute_sample_weight("balanced", y)
  130. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  131. # Test with user-defined weights
  132. sample_weight = compute_sample_weight({1: 2, 2: 1}, y)
  133. assert_array_almost_equal(sample_weight, [2.0, 2.0, 2.0, 1.0, 1.0, 1.0])
  134. # Test with column vector of balanced classes
  135. y = np.asarray([[1], [1], [1], [2], [2], [2]])
  136. sample_weight = compute_sample_weight("balanced", y)
  137. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  138. # Test with unbalanced classes
  139. y = np.asarray([1, 1, 1, 2, 2, 2, 3])
  140. sample_weight = compute_sample_weight("balanced", y)
  141. expected_balanced = np.array(
  142. [0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 2.3333]
  143. )
  144. assert_array_almost_equal(sample_weight, expected_balanced, decimal=4)
  145. # Test with `None` weights
  146. sample_weight = compute_sample_weight(None, y)
  147. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  148. # Test with multi-output of balanced classes
  149. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  150. sample_weight = compute_sample_weight("balanced", y)
  151. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  152. # Test with multi-output with user-defined weights
  153. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  154. sample_weight = compute_sample_weight([{1: 2, 2: 1}, {0: 1, 1: 2}], y)
  155. assert_array_almost_equal(sample_weight, [2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
  156. # Test with multi-output of unbalanced classes
  157. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [3, -1]])
  158. sample_weight = compute_sample_weight("balanced", y)
  159. assert_array_almost_equal(sample_weight, expected_balanced**2, decimal=3)
  160. def test_compute_sample_weight_with_subsample():
  161. # Test compute_sample_weight with subsamples specified.
  162. # Test with balanced classes and all samples present
  163. y = np.asarray([1, 1, 1, 2, 2, 2])
  164. sample_weight = compute_sample_weight("balanced", y, indices=range(6))
  165. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  166. # Test with column vector of balanced classes and all samples present
  167. y = np.asarray([[1], [1], [1], [2], [2], [2]])
  168. sample_weight = compute_sample_weight("balanced", y, indices=range(6))
  169. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
  170. # Test with a subsample
  171. y = np.asarray([1, 1, 1, 2, 2, 2])
  172. sample_weight = compute_sample_weight("balanced", y, indices=range(4))
  173. assert_array_almost_equal(sample_weight, [2.0 / 3, 2.0 / 3, 2.0 / 3, 2.0, 2.0, 2.0])
  174. # Test with a bootstrap subsample
  175. y = np.asarray([1, 1, 1, 2, 2, 2])
  176. sample_weight = compute_sample_weight("balanced", y, indices=[0, 1, 1, 2, 2, 3])
  177. expected_balanced = np.asarray([0.6, 0.6, 0.6, 3.0, 3.0, 3.0])
  178. assert_array_almost_equal(sample_weight, expected_balanced)
  179. # Test with a bootstrap subsample for multi-output
  180. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  181. sample_weight = compute_sample_weight("balanced", y, indices=[0, 1, 1, 2, 2, 3])
  182. assert_array_almost_equal(sample_weight, expected_balanced**2)
  183. # Test with a missing class
  184. y = np.asarray([1, 1, 1, 2, 2, 2, 3])
  185. sample_weight = compute_sample_weight("balanced", y, indices=range(6))
  186. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0])
  187. # Test with a missing class for multi-output
  188. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [2, 2]])
  189. sample_weight = compute_sample_weight("balanced", y, indices=range(6))
  190. assert_array_almost_equal(sample_weight, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0])
  191. def test_compute_sample_weight_errors():
  192. # Test compute_sample_weight raises errors expected.
  193. # Invalid preset string
  194. y = np.asarray([1, 1, 1, 2, 2, 2])
  195. y_ = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  196. with pytest.raises(ValueError):
  197. compute_sample_weight("ni", y)
  198. with pytest.raises(ValueError):
  199. compute_sample_weight("ni", y, indices=range(4))
  200. with pytest.raises(ValueError):
  201. compute_sample_weight("ni", y_)
  202. with pytest.raises(ValueError):
  203. compute_sample_weight("ni", y_, indices=range(4))
  204. # Not "balanced" for subsample
  205. with pytest.raises(ValueError):
  206. compute_sample_weight({1: 2, 2: 1}, y, indices=range(4))
  207. # Not a list or preset for multi-output
  208. with pytest.raises(ValueError):
  209. compute_sample_weight({1: 2, 2: 1}, y_)
  210. # Incorrect length list for multi-output
  211. with pytest.raises(ValueError):
  212. compute_sample_weight([{1: 2, 2: 1}], y_)
  213. def test_compute_sample_weight_more_than_32():
  214. # Non-regression smoke test for #12146
  215. y = np.arange(50) # more than 32 distinct classes
  216. indices = np.arange(50) # use subsampling
  217. weight = compute_sample_weight("balanced", y, indices=indices)
  218. assert_array_almost_equal(weight, np.ones(y.shape[0]))
  219. def test_class_weight_does_not_contains_more_classes():
  220. """Check that class_weight can contain more labels than in y.
  221. Non-regression test for #22413
  222. """
  223. tree = DecisionTreeClassifier(class_weight={0: 1, 1: 10, 2: 20})
  224. # Does not raise
  225. tree.fit([[0, 0, 1], [1, 0, 1], [1, 2, 0]], [0, 0, 1])
  226. def test_compute_sample_weight_sparse():
  227. """Check that we can compute weight for sparse `y`."""
  228. y = sparse.csc_matrix(np.asarray([0, 1, 1])).T
  229. sample_weight = compute_sample_weight("balanced", y)
  230. assert_allclose(sample_weight, [1.5, 0.75, 0.75])