# _plotting.py
  1. import numpy as np
  2. from . import check_consistent_length, check_matplotlib_support
  3. from ._response import _get_response_values_binary
  4. from .multiclass import type_of_target
  5. from .validation import _check_pos_label_consistency
  6. class _BinaryClassifierCurveDisplayMixin:
  7. """Mixin class to be used in Displays requiring a binary classifier.
  8. The aim of this class is to centralize some validations regarding the estimator and
  9. the target and gather the response of the estimator.
  10. """
  11. def _validate_plot_params(self, *, ax=None, name=None):
  12. check_matplotlib_support(f"{self.__class__.__name__}.plot")
  13. import matplotlib.pyplot as plt
  14. if ax is None:
  15. _, ax = plt.subplots()
  16. name = self.estimator_name if name is None else name
  17. return ax, ax.figure, name
  18. @classmethod
  19. def _validate_and_get_response_values(
  20. cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
  21. ):
  22. check_matplotlib_support(f"{cls.__name__}.from_estimator")
  23. name = estimator.__class__.__name__ if name is None else name
  24. y_pred, pos_label = _get_response_values_binary(
  25. estimator,
  26. X,
  27. response_method=response_method,
  28. pos_label=pos_label,
  29. )
  30. return y_pred, pos_label, name
  31. @classmethod
  32. def _validate_from_predictions_params(
  33. cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
  34. ):
  35. check_matplotlib_support(f"{cls.__name__}.from_predictions")
  36. if type_of_target(y_true) != "binary":
  37. raise ValueError(
  38. f"The target y is not binary. Got {type_of_target(y_true)} type of"
  39. " target."
  40. )
  41. check_consistent_length(y_true, y_pred, sample_weight)
  42. pos_label = _check_pos_label_consistency(pos_label, y_true)
  43. name = name if name is not None else "Classifier"
  44. return pos_label, name
  45. def _validate_score_name(score_name, scoring, negate_score):
  46. """Validate the `score_name` parameter.
  47. If `score_name` is provided, we just return it as-is.
  48. If `score_name` is `None`, we use `Score` if `negate_score` is `False` and
  49. `Negative score` otherwise.
  50. If `score_name` is a string or a callable, we infer the name. We replace `_` by
  51. spaces and capitalize the first letter. We remove `neg_` and replace it by
  52. `"Negative"` if `negate_score` is `False` or just remove it otherwise.
  53. """
  54. if score_name is not None:
  55. return score_name
  56. elif scoring is None:
  57. return "Negative score" if negate_score else "Score"
  58. else:
  59. score_name = scoring.__name__ if callable(scoring) else scoring
  60. if negate_score:
  61. if score_name.startswith("neg_"):
  62. score_name = score_name[4:]
  63. else:
  64. score_name = f"Negative {score_name}"
  65. elif score_name.startswith("neg_"):
  66. score_name = f"Negative {score_name[4:]}"
  67. score_name = score_name.replace("_", " ")
  68. return score_name.capitalize()
  69. def _interval_max_min_ratio(data):
  70. """Compute the ratio between the largest and smallest inter-point distances.
  71. A value larger than 5 typically indicates that the parameter range would
  72. better be displayed with a log scale while a linear scale would be more
  73. suitable otherwise.
  74. """
  75. diff = np.diff(np.sort(data))
  76. return diff.max() / diff.min()