# test_response.py
  1. import numpy as np
  2. import pytest
  3. from sklearn.datasets import (
  4. load_iris,
  5. make_classification,
  6. make_multilabel_classification,
  7. make_regression,
  8. )
  9. from sklearn.linear_model import (
  10. LinearRegression,
  11. LogisticRegression,
  12. )
  13. from sklearn.multioutput import ClassifierChain
  14. from sklearn.preprocessing import scale
  15. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  16. from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
  17. from sklearn.utils._response import _get_response_values, _get_response_values_binary
  18. from sklearn.utils._testing import assert_allclose, assert_array_equal
  19. X, y = load_iris(return_X_y=True)
  20. # scale the data to avoid ConvergenceWarning with LogisticRegression
  21. X = scale(X, copy=False)
  22. X_binary, y_binary = X[:100], y[:100]
  23. @pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"])
  24. def test_get_response_values_regressor_error(response_method):
  25. """Check the error message with regressor an not supported response
  26. method."""
  27. my_estimator = _MockEstimatorOnOffPrediction(response_methods=[response_method])
  28. X = "mocking_data", "mocking_target"
  29. err_msg = f"{my_estimator.__class__.__name__} should either be a classifier"
  30. with pytest.raises(ValueError, match=err_msg):
  31. _get_response_values(my_estimator, X, response_method=response_method)
  32. def test_get_response_values_regressor():
  33. """Check the behaviour of `_get_response_values` with regressor."""
  34. X, y = make_regression(n_samples=10, random_state=0)
  35. regressor = LinearRegression().fit(X, y)
  36. y_pred, pos_label = _get_response_values(
  37. regressor,
  38. X,
  39. response_method="predict",
  40. )
  41. assert_array_equal(y_pred, regressor.predict(X))
  42. assert pos_label is None
  43. @pytest.mark.parametrize(
  44. "response_method",
  45. ["predict_proba", "decision_function", "predict"],
  46. )
  47. def test_get_response_values_classifier_unknown_pos_label(response_method):
  48. """Check that `_get_response_values` raises the proper error message with
  49. classifier."""
  50. X, y = make_classification(n_samples=10, n_classes=2, random_state=0)
  51. classifier = LogisticRegression().fit(X, y)
  52. # provide a `pos_label` which is not in `y`
  53. err_msg = r"pos_label=whatever is not a valid label: It should be one of \[0 1\]"
  54. with pytest.raises(ValueError, match=err_msg):
  55. _get_response_values(
  56. classifier,
  57. X,
  58. response_method=response_method,
  59. pos_label="whatever",
  60. )
  61. def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba():
  62. """Check that `_get_response_values` will raise an error when `y_pred` has a
  63. single class with `predict_proba`."""
  64. X, y_two_class = make_classification(n_samples=10, n_classes=2, random_state=0)
  65. y_single_class = np.zeros_like(y_two_class)
  66. classifier = DecisionTreeClassifier().fit(X, y_single_class)
  67. err_msg = (
  68. r"Got predict_proba of shape \(10, 1\), but need classifier with "
  69. r"two classes"
  70. )
  71. with pytest.raises(ValueError, match=err_msg):
  72. _get_response_values(classifier, X, response_method="predict_proba")
  73. def test_get_response_values_binary_classifier_decision_function():
  74. """Check the behaviour of `_get_response_values` with `decision_function`
  75. and binary classifier."""
  76. X, y = make_classification(
  77. n_samples=10,
  78. n_classes=2,
  79. weights=[0.3, 0.7],
  80. random_state=0,
  81. )
  82. classifier = LogisticRegression().fit(X, y)
  83. response_method = "decision_function"
  84. # default `pos_label`
  85. y_pred, pos_label = _get_response_values(
  86. classifier,
  87. X,
  88. response_method=response_method,
  89. pos_label=None,
  90. )
  91. assert_allclose(y_pred, classifier.decision_function(X))
  92. assert pos_label == 1
  93. # when forcing `pos_label=classifier.classes_[0]`
  94. y_pred, pos_label = _get_response_values(
  95. classifier,
  96. X,
  97. response_method=response_method,
  98. pos_label=classifier.classes_[0],
  99. )
  100. assert_allclose(y_pred, classifier.decision_function(X) * -1)
  101. assert pos_label == 0
  102. def test_get_response_values_binary_classifier_predict_proba():
  103. """Check that `_get_response_values` with `predict_proba` and binary
  104. classifier."""
  105. X, y = make_classification(
  106. n_samples=10,
  107. n_classes=2,
  108. weights=[0.3, 0.7],
  109. random_state=0,
  110. )
  111. classifier = LogisticRegression().fit(X, y)
  112. response_method = "predict_proba"
  113. # default `pos_label`
  114. y_pred, pos_label = _get_response_values(
  115. classifier,
  116. X,
  117. response_method=response_method,
  118. pos_label=None,
  119. )
  120. assert_allclose(y_pred, classifier.predict_proba(X)[:, 1])
  121. assert pos_label == 1
  122. # when forcing `pos_label=classifier.classes_[0]`
  123. y_pred, pos_label = _get_response_values(
  124. classifier,
  125. X,
  126. response_method=response_method,
  127. pos_label=classifier.classes_[0],
  128. )
  129. assert_allclose(y_pred, classifier.predict_proba(X)[:, 0])
  130. assert pos_label == 0
  131. @pytest.mark.parametrize(
  132. "estimator, X, y, err_msg, params",
  133. [
  134. (
  135. DecisionTreeRegressor(),
  136. X_binary,
  137. y_binary,
  138. "Expected 'estimator' to be a binary classifier",
  139. {"response_method": "auto"},
  140. ),
  141. (
  142. DecisionTreeClassifier(),
  143. X_binary,
  144. y_binary,
  145. r"pos_label=unknown is not a valid label: It should be one of \[0 1\]",
  146. {"response_method": "auto", "pos_label": "unknown"},
  147. ),
  148. (
  149. DecisionTreeClassifier(),
  150. X,
  151. y,
  152. "be a binary classifier. Got 3 classes instead.",
  153. {"response_method": "predict_proba"},
  154. ),
  155. ],
  156. )
  157. def test_get_response_error(estimator, X, y, err_msg, params):
  158. """Check that we raise the proper error messages in _get_response_values_binary."""
  159. estimator.fit(X, y)
  160. with pytest.raises(ValueError, match=err_msg):
  161. _get_response_values_binary(estimator, X, **params)
  162. def test_get_response_predict_proba():
  163. """Check the behaviour of `_get_response_values_binary` using `predict_proba`."""
  164. classifier = DecisionTreeClassifier().fit(X_binary, y_binary)
  165. y_proba, pos_label = _get_response_values_binary(
  166. classifier, X_binary, response_method="predict_proba"
  167. )
  168. np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1])
  169. assert pos_label == 1
  170. y_proba, pos_label = _get_response_values_binary(
  171. classifier, X_binary, response_method="predict_proba", pos_label=0
  172. )
  173. np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0])
  174. assert pos_label == 0
  175. def test_get_response_decision_function():
  176. """Check the behaviour of `_get_response_values_binary` using decision_function."""
  177. classifier = LogisticRegression().fit(X_binary, y_binary)
  178. y_score, pos_label = _get_response_values_binary(
  179. classifier, X_binary, response_method="decision_function"
  180. )
  181. np.testing.assert_allclose(y_score, classifier.decision_function(X_binary))
  182. assert pos_label == 1
  183. y_score, pos_label = _get_response_values_binary(
  184. classifier, X_binary, response_method="decision_function", pos_label=0
  185. )
  186. np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
  187. assert pos_label == 0
  188. @pytest.mark.parametrize(
  189. "estimator, response_method",
  190. [
  191. (DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"),
  192. (LogisticRegression(), "decision_function"),
  193. ],
  194. )
  195. def test_get_response_values_multiclass(estimator, response_method):
  196. """Check that we can call `_get_response_values` with a multiclass estimator.
  197. It should return the predictions untouched.
  198. """
  199. estimator.fit(X, y)
  200. predictions, pos_label = _get_response_values(
  201. estimator, X, response_method=response_method
  202. )
  203. assert pos_label is None
  204. assert predictions.shape == (X.shape[0], len(estimator.classes_))
  205. if response_method == "predict_proba":
  206. assert np.logical_and(predictions >= 0, predictions <= 1).all()
  207. @pytest.mark.parametrize(
  208. "response_method", ["predict_proba", "decision_function", "predict"]
  209. )
  210. def test_get_response_values_multilabel_indicator(response_method):
  211. X, Y = make_multilabel_classification(random_state=0)
  212. estimator = ClassifierChain(LogisticRegression()).fit(X, Y)
  213. y_pred, pos_label = _get_response_values(
  214. estimator, X, response_method=response_method
  215. )
  216. assert pos_label is None
  217. assert y_pred.shape == Y.shape
  218. if response_method == "predict_proba":
  219. assert np.logical_and(y_pred >= 0, y_pred <= 1).all()
  220. elif response_method == "decision_function":
  221. # values returned by `decision_function` are not bounded in [0, 1]
  222. assert (y_pred < 0).sum() > 0
  223. assert (y_pred > 1).sum() > 0
  224. else: # response_method == "predict"
  225. assert np.logical_or(y_pred == 0, y_pred == 1).all()