_response.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
"""Utilities to get the response values of a classifier or a regressor.

They provide uniform checks and validation of the predictions returned by
`predict_proba`, `decision_function` or `predict`, for use by the displays
and the scorers.
"""
  4. import numpy as np
  5. from ..base import is_classifier
  6. from .multiclass import type_of_target
  7. from .validation import _check_response_method, check_is_fitted
  8. def _process_predict_proba(*, y_pred, target_type, classes, pos_label):
  9. """Get the response values when the response method is `predict_proba`.
  10. This function process the `y_pred` array in the binary and multi-label cases.
  11. In the binary case, it selects the column corresponding to the positive
  12. class. In the multi-label case, it stacks the predictions if they are not
  13. in the "compressed" format `(n_samples, n_outputs)`.
  14. Parameters
  15. ----------
  16. y_pred : ndarray
  17. Output of `estimator.predict_proba`. The shape depends on the target type:
  18. - for binary classification, it is a 2d array of shape `(n_samples, 2)`;
  19. - for multiclass classification, it is a 2d array of shape
  20. `(n_samples, n_classes)`;
  21. - for multilabel classification, it is either a list of 2d arrays of shape
  22. `(n_samples, 2)` (e.g. `RandomForestClassifier` or `KNeighborsClassifier`) or
  23. an array of shape `(n_samples, n_outputs)` (e.g. `MLPClassifier` or
  24. `RidgeClassifier`).
  25. target_type : {"binary", "multiclass", "multilabel-indicator"}
  26. Type of the target.
  27. classes : ndarray of shape (n_classes,) or list of such arrays
  28. Class labels as reported by `estimator.classes_`.
  29. pos_label : int, float, bool or str
  30. Only used with binary and multiclass targets.
  31. Returns
  32. -------
  33. y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
  34. (n_samples, n_output)
  35. Compressed predictions format as requested by the metrics.
  36. """
  37. if target_type == "binary" and y_pred.shape[1] < 2:
  38. # We don't handle classifiers trained on a single class.
  39. raise ValueError(
  40. f"Got predict_proba of shape {y_pred.shape}, but need "
  41. "classifier with two classes."
  42. )
  43. if target_type == "binary":
  44. col_idx = np.flatnonzero(classes == pos_label)[0]
  45. return y_pred[:, col_idx]
  46. elif target_type == "multilabel-indicator":
  47. # Use a compress format of shape `(n_samples, n_output)`.
  48. # Only `MLPClassifier` and `RidgeClassifier` return an array of shape
  49. # `(n_samples, n_outputs)`.
  50. if isinstance(y_pred, list):
  51. # list of arrays of shape `(n_samples, 2)`
  52. return np.vstack([p[:, -1] for p in y_pred]).T
  53. else:
  54. # array of shape `(n_samples, n_outputs)`
  55. return y_pred
  56. return y_pred
  57. def _process_decision_function(*, y_pred, target_type, classes, pos_label):
  58. """Get the response values when the response method is `decision_function`.
  59. This function process the `y_pred` array in the binary and multi-label cases.
  60. In the binary case, it inverts the sign of the score if the positive label
  61. is not `classes[1]`. In the multi-label case, it stacks the predictions if
  62. they are not in the "compressed" format `(n_samples, n_outputs)`.
  63. Parameters
  64. ----------
  65. y_pred : ndarray
  66. Output of `estimator.predict_proba`. The shape depends on the target type:
  67. - for binary classification, it is a 1d array of shape `(n_samples,)` where the
  68. sign is assuming that `classes[1]` is the positive class;
  69. - for multiclass classification, it is a 2d array of shape
  70. `(n_samples, n_classes)`;
  71. - for multilabel classification, it is a 2d array of shape `(n_samples,
  72. n_outputs)`.
  73. target_type : {"binary", "multiclass", "multilabel-indicator"}
  74. Type of the target.
  75. classes : ndarray of shape (n_classes,) or list of such arrays
  76. Class labels as reported by `estimator.classes_`.
  77. pos_label : int, float, bool or str
  78. Only used with binary and multiclass targets.
  79. Returns
  80. -------
  81. y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
  82. (n_samples, n_output)
  83. Compressed predictions format as requested by the metrics.
  84. """
  85. if target_type == "binary" and pos_label == classes[0]:
  86. return -1 * y_pred
  87. return y_pred
  88. def _get_response_values(
  89. estimator,
  90. X,
  91. response_method,
  92. pos_label=None,
  93. ):
  94. """Compute the response values of a classifier or a regressor.
  95. The response values are predictions such that it follows the following shape:
  96. - for binary classification, it is a 1d array of shape `(n_samples,)`;
  97. - for multiclass classification, it is a 2d array of shape `(n_samples, n_classes)`;
  98. - for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
  99. - for regression, it is a 1d array of shape `(n_samples,)`.
  100. If `estimator` is a binary classifier, also return the label for the
  101. effective positive class.
  102. This utility is used primarily in the displays and the scikit-learn scorers.
  103. .. versionadded:: 1.3
  104. Parameters
  105. ----------
  106. estimator : estimator instance
  107. Fitted classifier or regressor or a fitted :class:`~sklearn.pipeline.Pipeline`
  108. in which the last estimator is a classifier or a regressor.
  109. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  110. Input values.
  111. response_method : {"predict_proba", "decision_function", "predict"} or \
  112. list of such str
  113. Specifies the response method to use get prediction from an estimator
  114. (i.e. :term:`predict_proba`, :term:`decision_function` or
  115. :term:`predict`). Possible choices are:
  116. - if `str`, it corresponds to the name to the method to return;
  117. - if a list of `str`, it provides the method names in order of
  118. preference. The method returned corresponds to the first method in
  119. the list and which is implemented by `estimator`.
  120. pos_label : int, float, bool or str, default=None
  121. The class considered as the positive class when computing
  122. the metrics. By default, `estimators.classes_[1]` is
  123. considered as the positive class.
  124. Returns
  125. -------
  126. y_pred : ndarray of shape (n_samples,), (n_samples, n_classes) or \
  127. (n_samples, n_outputs)
  128. Target scores calculated from the provided `response_method`
  129. and `pos_label`.
  130. pos_label : int, float, bool, str or None
  131. The class considered as the positive class when computing
  132. the metrics. Returns `None` if `estimator` is a regressor.
  133. Raises
  134. ------
  135. ValueError
  136. If `pos_label` is not a valid label.
  137. If the shape of `y_pred` is not consistent for binary classifier.
  138. If the response method can be applied to a classifier only and
  139. `estimator` is a regressor.
  140. """
  141. from sklearn.base import is_classifier # noqa
  142. if is_classifier(estimator):
  143. prediction_method = _check_response_method(estimator, response_method)
  144. classes = estimator.classes_
  145. target_type = type_of_target(classes)
  146. if target_type in ("binary", "multiclass"):
  147. if pos_label is not None and pos_label not in classes.tolist():
  148. raise ValueError(
  149. f"pos_label={pos_label} is not a valid label: It should be "
  150. f"one of {classes}"
  151. )
  152. elif pos_label is None and target_type == "binary":
  153. pos_label = classes[-1]
  154. y_pred = prediction_method(X)
  155. if prediction_method.__name__ == "predict_proba":
  156. y_pred = _process_predict_proba(
  157. y_pred=y_pred,
  158. target_type=target_type,
  159. classes=classes,
  160. pos_label=pos_label,
  161. )
  162. elif prediction_method.__name__ == "decision_function":
  163. y_pred = _process_decision_function(
  164. y_pred=y_pred,
  165. target_type=target_type,
  166. classes=classes,
  167. pos_label=pos_label,
  168. )
  169. else: # estimator is a regressor
  170. if response_method != "predict":
  171. raise ValueError(
  172. f"{estimator.__class__.__name__} should either be a classifier to be "
  173. f"used with response_method={response_method} or the response_method "
  174. "should be 'predict'. Got a regressor with response_method="
  175. f"{response_method} instead."
  176. )
  177. y_pred, pos_label = estimator.predict(X), None
  178. return y_pred, pos_label
  179. def _get_response_values_binary(estimator, X, response_method, pos_label=None):
  180. """Compute the response values of a binary classifier.
  181. Parameters
  182. ----------
  183. estimator : estimator instance
  184. Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
  185. in which the last estimator is a binary classifier.
  186. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  187. Input values.
  188. response_method : {'auto', 'predict_proba', 'decision_function'}
  189. Specifies whether to use :term:`predict_proba` or
  190. :term:`decision_function` as the target response. If set to 'auto',
  191. :term:`predict_proba` is tried first and if it does not exist
  192. :term:`decision_function` is tried next.
  193. pos_label : int, float, bool or str, default=None
  194. The class considered as the positive class when computing
  195. the metrics. By default, `estimators.classes_[1]` is
  196. considered as the positive class.
  197. Returns
  198. -------
  199. y_pred : ndarray of shape (n_samples,)
  200. Target scores calculated from the provided response_method
  201. and pos_label.
  202. pos_label : int, float, bool or str
  203. The class considered as the positive class when computing
  204. the metrics.
  205. """
  206. classification_error = "Expected 'estimator' to be a binary classifier."
  207. check_is_fitted(estimator)
  208. if not is_classifier(estimator):
  209. raise ValueError(
  210. classification_error + f" Got {estimator.__class__.__name__} instead."
  211. )
  212. elif len(estimator.classes_) != 2:
  213. raise ValueError(
  214. classification_error + f" Got {len(estimator.classes_)} classes instead."
  215. )
  216. if response_method == "auto":
  217. response_method = ["predict_proba", "decision_function"]
  218. return _get_response_values(
  219. estimator,
  220. X,
  221. response_method,
  222. pos_label=pos_label,
  223. )